1 | //===- EmitC.cpp - EmitC Dialect ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
10 | #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/BuiltinAttributes.h" |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/DialectImplementation.h" |
15 | #include "mlir/IR/IRMapping.h" |
16 | #include "mlir/IR/Types.h" |
17 | #include "mlir/Interfaces/FunctionImplementation.h" |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/ADT/StringExtras.h" |
20 | #include "llvm/ADT/TypeSwitch.h" |
21 | #include "llvm/Support/Casting.h" |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::emitc; |
25 | |
26 | #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc" |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // EmitCDialect |
30 | //===----------------------------------------------------------------------===// |
31 | |
/// Registers the dialect's operations, types and attributes. Each list is
/// obtained by defining the matching `GET_*_LIST` macro and including the
/// corresponding `.cpp.inc` file inside the template argument list.
void EmitCDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
      >();
}
46 | |
47 | /// Materialize a single constant operation from a given attribute value with |
48 | /// the desired resultant type. |
49 | Operation *EmitCDialect::materializeConstant(OpBuilder &builder, |
50 | Attribute value, Type type, |
51 | Location loc) { |
52 | return builder.create<emitc::ConstantOp>(loc, type, value); |
53 | } |
54 | |
55 | /// Default callback for builders of ops carrying a region. Inserts a yield |
56 | /// without arguments. |
57 | void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { |
58 | builder.create<emitc::YieldOp>(loc); |
59 | } |
60 | |
61 | bool mlir::emitc::isSupportedEmitCType(Type type) { |
62 | if (llvm::isa<emitc::OpaqueType>(type)) |
63 | return true; |
64 | if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type)) |
65 | return isSupportedEmitCType(ptrType.getPointee()); |
66 | if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) { |
67 | auto elemType = arrayType.getElementType(); |
68 | return !llvm::isa<emitc::ArrayType>(elemType) && |
69 | isSupportedEmitCType(elemType); |
70 | } |
71 | if (type.isIndex()) |
72 | return true; |
73 | if (llvm::isa<IntegerType>(Val: type)) |
74 | return isSupportedIntegerType(type); |
75 | if (llvm::isa<FloatType>(Val: type)) |
76 | return isSupportedFloatType(type); |
77 | if (auto tensorType = llvm::dyn_cast<TensorType>(Val&: type)) { |
78 | if (!tensorType.hasStaticShape()) { |
79 | return false; |
80 | } |
81 | auto elemType = tensorType.getElementType(); |
82 | if (llvm::isa<emitc::ArrayType>(elemType)) { |
83 | return false; |
84 | } |
85 | return isSupportedEmitCType(type: elemType); |
86 | } |
87 | if (auto tupleType = llvm::dyn_cast<TupleType>(type)) { |
88 | return llvm::all_of(tupleType.getTypes(), [](Type type) { |
89 | return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type); |
90 | }); |
91 | } |
92 | return false; |
93 | } |
94 | |
95 | bool mlir::emitc::isSupportedIntegerType(Type type) { |
96 | if (auto intType = llvm::dyn_cast<IntegerType>(type)) { |
97 | switch (intType.getWidth()) { |
98 | case 1: |
99 | case 8: |
100 | case 16: |
101 | case 32: |
102 | case 64: |
103 | return true; |
104 | default: |
105 | return false; |
106 | } |
107 | } |
108 | return false; |
109 | } |
110 | |
111 | bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) { |
112 | return llvm::isa<IndexType, emitc::OpaqueType>(type) || |
113 | isSupportedIntegerType(type); |
114 | } |
115 | |
116 | bool mlir::emitc::isSupportedFloatType(Type type) { |
117 | if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) { |
118 | switch (floatType.getWidth()) { |
119 | case 32: |
120 | case 64: |
121 | return true; |
122 | default: |
123 | return false; |
124 | } |
125 | } |
126 | return false; |
127 | } |
128 | |
129 | /// Check that the type of the initial value is compatible with the operations |
130 | /// result type. |
131 | static LogicalResult verifyInitializationAttribute(Operation *op, |
132 | Attribute value) { |
133 | assert(op->getNumResults() == 1 && "operation must have 1 result" ); |
134 | |
135 | if (llvm::isa<emitc::OpaqueAttr>(value)) |
136 | return success(); |
137 | |
138 | if (llvm::isa<StringAttr>(Val: value)) |
139 | return op->emitOpError() |
140 | << "string attributes are not supported, use #emitc.opaque instead" ; |
141 | |
142 | Type resultType = op->getResult(idx: 0).getType(); |
143 | Type attrType = cast<TypedAttr>(value).getType(); |
144 | |
145 | if (resultType != attrType) |
146 | return op->emitOpError() |
147 | << "requires attribute to either be an #emitc.opaque attribute or " |
148 | "it's type (" |
149 | << attrType << ") to match the op's result type (" << resultType |
150 | << ")" ; |
151 | |
152 | return success(); |
153 | } |
154 | |
155 | //===----------------------------------------------------------------------===// |
156 | // AddOp |
157 | //===----------------------------------------------------------------------===// |
158 | |
159 | LogicalResult AddOp::verify() { |
160 | Type lhsType = getLhs().getType(); |
161 | Type rhsType = getRhs().getType(); |
162 | |
163 | if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType)) |
164 | return emitOpError("requires that at most one operand is a pointer" ); |
165 | |
166 | if ((isa<emitc::PointerType>(lhsType) && |
167 | !isa<IntegerType, emitc::OpaqueType>(rhsType)) || |
168 | (isa<emitc::PointerType>(rhsType) && |
169 | !isa<IntegerType, emitc::OpaqueType>(lhsType))) |
170 | return emitOpError("requires that one operand is an integer or of opaque " |
171 | "type if the other is a pointer" ); |
172 | |
173 | return success(); |
174 | } |
175 | |
176 | //===----------------------------------------------------------------------===// |
177 | // ApplyOp |
178 | //===----------------------------------------------------------------------===// |
179 | |
180 | LogicalResult ApplyOp::verify() { |
181 | StringRef applicableOperatorStr = getApplicableOperator(); |
182 | |
183 | // Applicable operator must not be empty. |
184 | if (applicableOperatorStr.empty()) |
185 | return emitOpError("applicable operator must not be empty" ); |
186 | |
187 | // Only `*` and `&` are supported. |
188 | if (applicableOperatorStr != "&" && applicableOperatorStr != "*" ) |
189 | return emitOpError("applicable operator is illegal" ); |
190 | |
191 | Operation *op = getOperand().getDefiningOp(); |
192 | if (op && dyn_cast<ConstantOp>(op)) |
193 | return emitOpError("cannot apply to constant" ); |
194 | |
195 | return success(); |
196 | } |
197 | |
198 | //===----------------------------------------------------------------------===// |
199 | // AssignOp |
200 | //===----------------------------------------------------------------------===// |
201 | |
202 | /// The assign op requires that the assigned value's type matches the |
203 | /// assigned-to variable type. |
204 | LogicalResult emitc::AssignOp::verify() { |
205 | Value variable = getVar(); |
206 | Operation *variableDef = variable.getDefiningOp(); |
207 | if (!variableDef || |
208 | !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef)) |
209 | return emitOpError() << "requires first operand (" << variable |
210 | << ") to be a Variable or subscript" ; |
211 | |
212 | Value value = getValue(); |
213 | if (variable.getType() != value.getType()) |
214 | return emitOpError() << "requires value's type (" << value.getType() |
215 | << ") to match variable's type (" << variable.getType() |
216 | << ")" ; |
217 | if (isa<ArrayType>(variable.getType())) |
218 | return emitOpError() << "cannot assign to array type" ; |
219 | return success(); |
220 | } |
221 | |
222 | //===----------------------------------------------------------------------===// |
223 | // CastOp |
224 | //===----------------------------------------------------------------------===// |
225 | |
226 | bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
227 | Type input = inputs.front(), output = outputs.front(); |
228 | |
229 | return ((llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType, |
230 | emitc::PointerType>(input)) && |
231 | (llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType, |
232 | emitc::PointerType>(output))); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
// CallOpaqueOp
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | LogicalResult emitc::CallOpaqueOp::verify() { |
240 | // Callee must not be empty. |
241 | if (getCallee().empty()) |
242 | return emitOpError("callee must not be empty" ); |
243 | |
244 | if (std::optional<ArrayAttr> argsAttr = getArgs()) { |
245 | for (Attribute arg : *argsAttr) { |
246 | auto intAttr = llvm::dyn_cast<IntegerAttr>(arg); |
247 | if (intAttr && llvm::isa<IndexType>(intAttr.getType())) { |
248 | int64_t index = intAttr.getInt(); |
249 | // Args with elements of type index must be in range |
250 | // [0..operands.size). |
251 | if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands()))) |
252 | return emitOpError("index argument is out of range" ); |
253 | |
254 | // Args with elements of type ArrayAttr must have a type. |
255 | } else if (llvm::isa<ArrayAttr>( |
256 | arg) /*&& llvm::isa<NoneType>(arg.getType())*/) { |
257 | // FIXME: Array attributes never have types |
258 | return emitOpError("array argument has no type" ); |
259 | } |
260 | } |
261 | } |
262 | |
263 | if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) { |
264 | for (Attribute tArg : *templateArgsAttr) { |
265 | if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg)) |
266 | return emitOpError("template argument has invalid type" ); |
267 | } |
268 | } |
269 | |
270 | if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) { |
271 | return emitOpError() << "cannot return array type" ; |
272 | } |
273 | |
274 | return success(); |
275 | } |
276 | |
277 | //===----------------------------------------------------------------------===// |
278 | // ConstantOp |
279 | //===----------------------------------------------------------------------===// |
280 | |
281 | LogicalResult emitc::ConstantOp::verify() { |
282 | Attribute value = getValueAttr(); |
283 | if (failed(verifyInitializationAttribute(getOperation(), value))) |
284 | return failure(); |
285 | if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) { |
286 | if (opaqueValue.getValue().empty()) |
287 | return emitOpError() << "value must not be empty" ; |
288 | } |
289 | return success(); |
290 | } |
291 | |
292 | OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } |
293 | |
294 | //===----------------------------------------------------------------------===// |
295 | // ExpressionOp |
296 | //===----------------------------------------------------------------------===// |
297 | |
298 | Operation *ExpressionOp::getRootOp() { |
299 | auto yieldOp = cast<YieldOp>(getBody()->getTerminator()); |
300 | Value yieldedValue = yieldOp.getResult(); |
301 | Operation *rootOp = yieldedValue.getDefiningOp(); |
302 | assert(rootOp && "Yielded value not defined within expression" ); |
303 | return rootOp; |
304 | } |
305 | |
306 | LogicalResult ExpressionOp::verify() { |
307 | Type resultType = getResult().getType(); |
308 | Region ®ion = getRegion(); |
309 | |
310 | Block &body = region.front(); |
311 | |
312 | if (!body.mightHaveTerminator()) |
313 | return emitOpError("must yield a value at termination" ); |
314 | |
315 | auto yield = cast<YieldOp>(body.getTerminator()); |
316 | Value yieldResult = yield.getResult(); |
317 | |
318 | if (!yieldResult) |
319 | return emitOpError("must yield a value at termination" ); |
320 | |
321 | Type yieldType = yieldResult.getType(); |
322 | |
323 | if (resultType != yieldType) |
324 | return emitOpError("requires yielded type to match return type" ); |
325 | |
326 | for (Operation &op : region.front().without_terminator()) { |
327 | if (!op.hasTrait<OpTrait::emitc::CExpression>()) |
328 | return emitOpError("contains an unsupported operation" ); |
329 | if (op.getNumResults() != 1) |
330 | return emitOpError("requires exactly one result for each operation" ); |
331 | if (!op.getResult(0).hasOneUse()) |
332 | return emitOpError("requires exactly one use for each operation" ); |
333 | } |
334 | |
335 | return success(); |
336 | } |
337 | |
338 | //===----------------------------------------------------------------------===// |
339 | // ForOp |
340 | //===----------------------------------------------------------------------===// |
341 | |
342 | void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, |
343 | Value ub, Value step, BodyBuilderFn bodyBuilder) { |
344 | OpBuilder::InsertionGuard g(builder); |
345 | result.addOperands({lb, ub, step}); |
346 | Type t = lb.getType(); |
347 | Region *bodyRegion = result.addRegion(); |
348 | Block *bodyBlock = builder.createBlock(bodyRegion); |
349 | bodyBlock->addArgument(t, result.location); |
350 | |
351 | // Create the default terminator if the builder is not provided. |
352 | if (!bodyBuilder) { |
353 | ForOp::ensureTerminator(*bodyRegion, builder, result.location); |
354 | } else { |
355 | OpBuilder::InsertionGuard guard(builder); |
356 | builder.setInsertionPointToStart(bodyBlock); |
357 | bodyBuilder(builder, result.location, bodyBlock->getArgument(0)); |
358 | } |
359 | } |
360 | |
361 | void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {} |
362 | |
363 | ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { |
364 | Builder &builder = parser.getBuilder(); |
365 | Type type; |
366 | |
367 | OpAsmParser::Argument inductionVariable; |
368 | OpAsmParser::UnresolvedOperand lb, ub, step; |
369 | |
370 | // Parse the induction variable followed by '='. |
371 | if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || |
372 | // Parse loop bounds. |
373 | parser.parseOperand(lb) || parser.parseKeyword("to" ) || |
374 | parser.parseOperand(ub) || parser.parseKeyword("step" ) || |
375 | parser.parseOperand(step)) |
376 | return failure(); |
377 | |
378 | // Parse the optional initial iteration arguments. |
379 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
380 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
381 | regionArgs.push_back(inductionVariable); |
382 | |
383 | // Parse optional type, else assume Index. |
384 | if (parser.parseOptionalColon()) |
385 | type = builder.getIndexType(); |
386 | else if (parser.parseType(type)) |
387 | return failure(); |
388 | |
389 | // Resolve input operands. |
390 | regionArgs.front().type = type; |
391 | if (parser.resolveOperand(lb, type, result.operands) || |
392 | parser.resolveOperand(ub, type, result.operands) || |
393 | parser.resolveOperand(step, type, result.operands)) |
394 | return failure(); |
395 | |
396 | // Parse the body region. |
397 | Region *body = result.addRegion(); |
398 | if (parser.parseRegion(*body, regionArgs)) |
399 | return failure(); |
400 | |
401 | ForOp::ensureTerminator(*body, builder, result.location); |
402 | |
403 | // Parse the optional attribute list. |
404 | if (parser.parseOptionalAttrDict(result.attributes)) |
405 | return failure(); |
406 | |
407 | return success(); |
408 | } |
409 | |
410 | void ForOp::print(OpAsmPrinter &p) { |
411 | p << " " << getInductionVar() << " = " << getLowerBound() << " to " |
412 | << getUpperBound() << " step " << getStep(); |
413 | |
414 | p << ' '; |
415 | if (Type t = getInductionVar().getType(); !t.isIndex()) |
416 | p << " : " << t << ' '; |
417 | p.printRegion(getRegion(), |
418 | /*printEntryBlockArgs=*/false, |
419 | /*printBlockTerminators=*/false); |
420 | p.printOptionalAttrDict((*this)->getAttrs()); |
421 | } |
422 | |
423 | LogicalResult ForOp::verifyRegions() { |
424 | // Check that the body defines as single block argument for the induction |
425 | // variable. |
426 | if (getInductionVar().getType() != getLowerBound().getType()) |
427 | return emitOpError( |
428 | "expected induction variable to be same type as bounds and step" ); |
429 | |
430 | return success(); |
431 | } |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // CallOp |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
438 | // Check that the callee attribute was specified. |
439 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
440 | if (!fnAttr) |
441 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
442 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
443 | if (!fn) |
444 | return emitOpError() << "'" << fnAttr.getValue() |
445 | << "' does not reference a valid function" ; |
446 | |
447 | // Verify that the operand and result types match the callee. |
448 | auto fnType = fn.getFunctionType(); |
449 | if (fnType.getNumInputs() != getNumOperands()) |
450 | return emitOpError("incorrect number of operands for callee" ); |
451 | |
452 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
453 | if (getOperand(i).getType() != fnType.getInput(i)) |
454 | return emitOpError("operand type mismatch: expected operand type " ) |
455 | << fnType.getInput(i) << ", but provided " |
456 | << getOperand(i).getType() << " for operand number " << i; |
457 | |
458 | if (fnType.getNumResults() != getNumResults()) |
459 | return emitOpError("incorrect number of results for callee" ); |
460 | |
461 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
462 | if (getResult(i).getType() != fnType.getResult(i)) { |
463 | auto diag = emitOpError("result type mismatch at index " ) << i; |
464 | diag.attachNote() << " op result types: " << getResultTypes(); |
465 | diag.attachNote() << "function result types: " << fnType.getResults(); |
466 | return diag; |
467 | } |
468 | |
469 | return success(); |
470 | } |
471 | |
472 | FunctionType CallOp::getCalleeType() { |
473 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
474 | } |
475 | |
476 | //===----------------------------------------------------------------------===// |
477 | // DeclareFuncOp |
478 | //===----------------------------------------------------------------------===// |
479 | |
480 | LogicalResult |
481 | DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
482 | // Check that the sym_name attribute was specified. |
483 | auto fnAttr = getSymNameAttr(); |
484 | if (!fnAttr) |
485 | return emitOpError("requires a 'sym_name' symbol reference attribute" ); |
486 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
487 | if (!fn) |
488 | return emitOpError() << "'" << fnAttr.getValue() |
489 | << "' does not reference a valid function" ; |
490 | |
491 | return success(); |
492 | } |
493 | |
494 | //===----------------------------------------------------------------------===// |
495 | // FuncOp |
496 | //===----------------------------------------------------------------------===// |
497 | |
498 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
499 | FunctionType type, ArrayRef<NamedAttribute> attrs, |
500 | ArrayRef<DictionaryAttr> argAttrs) { |
501 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
502 | builder.getStringAttr(name)); |
503 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
504 | state.attributes.append(attrs.begin(), attrs.end()); |
505 | state.addRegion(); |
506 | |
507 | if (argAttrs.empty()) |
508 | return; |
509 | assert(type.getNumInputs() == argAttrs.size()); |
510 | function_interface_impl::addArgAndResultAttrs( |
511 | builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
512 | getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
513 | } |
514 | |
515 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
516 | auto buildFuncType = |
517 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
518 | function_interface_impl::VariadicFlag, |
519 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
520 | |
521 | return function_interface_impl::parseFunctionOp( |
522 | parser, result, /*allowVariadic=*/false, |
523 | getFunctionTypeAttrName(result.name), buildFuncType, |
524 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
525 | } |
526 | |
527 | void FuncOp::print(OpAsmPrinter &p) { |
528 | function_interface_impl::printFunctionOp( |
529 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
530 | getArgAttrsAttrName(), getResAttrsAttrName()); |
531 | } |
532 | |
533 | LogicalResult FuncOp::verify() { |
534 | if (getNumResults() > 1) |
535 | return emitOpError("requires zero or exactly one result, but has " ) |
536 | << getNumResults(); |
537 | |
538 | if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0])) |
539 | return emitOpError("cannot return array type" ); |
540 | |
541 | return success(); |
542 | } |
543 | |
544 | //===----------------------------------------------------------------------===// |
545 | // ReturnOp |
546 | //===----------------------------------------------------------------------===// |
547 | |
548 | LogicalResult ReturnOp::verify() { |
549 | auto function = cast<FuncOp>((*this)->getParentOp()); |
550 | |
551 | // The operand number and types must match the function signature. |
552 | if (getNumOperands() != function.getNumResults()) |
553 | return emitOpError("has " ) |
554 | << getNumOperands() << " operands, but enclosing function (@" |
555 | << function.getName() << ") returns " << function.getNumResults(); |
556 | |
557 | if (function.getNumResults() == 1) |
558 | if (getOperand().getType() != function.getResultTypes()[0]) |
559 | return emitError() << "type of the return operand (" |
560 | << getOperand().getType() |
561 | << ") doesn't match function result type (" |
562 | << function.getResultTypes()[0] << ")" |
563 | << " in function @" << function.getName(); |
564 | return success(); |
565 | } |
566 | |
567 | //===----------------------------------------------------------------------===// |
568 | // IfOp |
569 | //===----------------------------------------------------------------------===// |
570 | |
571 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
572 | bool addThenBlock, bool addElseBlock) { |
573 | assert((!addElseBlock || addThenBlock) && |
574 | "must not create else block w/o then block" ); |
575 | result.addOperands(cond); |
576 | |
577 | // Add regions and blocks. |
578 | OpBuilder::InsertionGuard guard(builder); |
579 | Region *thenRegion = result.addRegion(); |
580 | if (addThenBlock) |
581 | builder.createBlock(thenRegion); |
582 | Region *elseRegion = result.addRegion(); |
583 | if (addElseBlock) |
584 | builder.createBlock(elseRegion); |
585 | } |
586 | |
587 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
588 | bool withElseRegion) { |
589 | result.addOperands(cond); |
590 | |
591 | // Build then region. |
592 | OpBuilder::InsertionGuard guard(builder); |
593 | Region *thenRegion = result.addRegion(); |
594 | builder.createBlock(thenRegion); |
595 | |
596 | // Build else region. |
597 | Region *elseRegion = result.addRegion(); |
598 | if (withElseRegion) { |
599 | builder.createBlock(elseRegion); |
600 | } |
601 | } |
602 | |
603 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
604 | function_ref<void(OpBuilder &, Location)> thenBuilder, |
605 | function_ref<void(OpBuilder &, Location)> elseBuilder) { |
606 | assert(thenBuilder && "the builder callback for 'then' must be present" ); |
607 | result.addOperands(cond); |
608 | |
609 | // Build then region. |
610 | OpBuilder::InsertionGuard guard(builder); |
611 | Region *thenRegion = result.addRegion(); |
612 | builder.createBlock(thenRegion); |
613 | thenBuilder(builder, result.location); |
614 | |
615 | // Build else region. |
616 | Region *elseRegion = result.addRegion(); |
617 | if (elseBuilder) { |
618 | builder.createBlock(elseRegion); |
619 | elseBuilder(builder, result.location); |
620 | } |
621 | } |
622 | |
623 | ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { |
624 | // Create the regions for 'then'. |
625 | result.regions.reserve(2); |
626 | Region *thenRegion = result.addRegion(); |
627 | Region *elseRegion = result.addRegion(); |
628 | |
629 | Builder &builder = parser.getBuilder(); |
630 | OpAsmParser::UnresolvedOperand cond; |
631 | Type i1Type = builder.getIntegerType(1); |
632 | if (parser.parseOperand(cond) || |
633 | parser.resolveOperand(cond, i1Type, result.operands)) |
634 | return failure(); |
635 | // Parse the 'then' region. |
636 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
637 | return failure(); |
638 | IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); |
639 | |
640 | // If we find an 'else' keyword then parse the 'else' region. |
641 | if (!parser.parseOptionalKeyword("else" )) { |
642 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
643 | return failure(); |
644 | IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); |
645 | } |
646 | |
647 | // Parse the optional attribute list. |
648 | if (parser.parseOptionalAttrDict(result.attributes)) |
649 | return failure(); |
650 | return success(); |
651 | } |
652 | |
653 | void IfOp::print(OpAsmPrinter &p) { |
654 | bool printBlockTerminators = false; |
655 | |
656 | p << " " << getCondition(); |
657 | p << ' '; |
658 | p.printRegion(getThenRegion(), |
659 | /*printEntryBlockArgs=*/false, |
660 | /*printBlockTerminators=*/printBlockTerminators); |
661 | |
662 | // Print the 'else' regions if it exists and has a block. |
663 | Region &elseRegion = getElseRegion(); |
664 | if (!elseRegion.empty()) { |
665 | p << " else " ; |
666 | p.printRegion(elseRegion, |
667 | /*printEntryBlockArgs=*/false, |
668 | /*printBlockTerminators=*/printBlockTerminators); |
669 | } |
670 | |
671 | p.printOptionalAttrDict((*this)->getAttrs()); |
672 | } |
673 | |
674 | /// Given the region at `index`, or the parent operation if `index` is None, |
675 | /// return the successor regions. These are the regions that may be selected |
676 | /// during the flow of control. `operands` is a set of optional attributes that |
677 | /// correspond to a constant value for each operand, or null if that operand is |
678 | /// not a constant. |
679 | void IfOp::getSuccessorRegions(RegionBranchPoint point, |
680 | SmallVectorImpl<RegionSuccessor> ®ions) { |
681 | // The `then` and the `else` region branch back to the parent operation. |
682 | if (!point.isParent()) { |
683 | regions.push_back(RegionSuccessor()); |
684 | return; |
685 | } |
686 | |
687 | regions.push_back(RegionSuccessor(&getThenRegion())); |
688 | |
689 | // Don't consider the else region if it is empty. |
690 | Region *elseRegion = &this->getElseRegion(); |
691 | if (elseRegion->empty()) |
692 | regions.push_back(RegionSuccessor()); |
693 | else |
694 | regions.push_back(RegionSuccessor(elseRegion)); |
695 | } |
696 | |
697 | void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, |
698 | SmallVectorImpl<RegionSuccessor> ®ions) { |
699 | FoldAdaptor adaptor(operands, *this); |
700 | auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); |
701 | if (!boolAttr || boolAttr.getValue()) |
702 | regions.emplace_back(&getThenRegion()); |
703 | |
704 | // If the else region is empty, execution continues after the parent op. |
705 | if (!boolAttr || !boolAttr.getValue()) { |
706 | if (!getElseRegion().empty()) |
707 | regions.emplace_back(&getElseRegion()); |
708 | else |
709 | regions.emplace_back(); |
710 | } |
711 | } |
712 | |
713 | void IfOp::getRegionInvocationBounds( |
714 | ArrayRef<Attribute> operands, |
715 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
716 | if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) { |
717 | // If the condition is known, then one region is known to be executed once |
718 | // and the other zero times. |
719 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
720 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
721 | } else { |
722 | // Non-constant condition. Each region may be executed 0 or 1 times. |
723 | invocationBounds.assign(2, {0, 1}); |
724 | } |
725 | } |
726 | |
727 | //===----------------------------------------------------------------------===// |
728 | // IncludeOp |
729 | //===----------------------------------------------------------------------===// |
730 | |
731 | void IncludeOp::print(OpAsmPrinter &p) { |
732 | bool standardInclude = getIsStandardInclude(); |
733 | |
734 | p << " " ; |
735 | if (standardInclude) |
736 | p << "<" ; |
737 | p << "\"" << getInclude() << "\"" ; |
738 | if (standardInclude) |
739 | p << ">" ; |
740 | } |
741 | |
742 | ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) { |
743 | bool standardInclude = !parser.parseOptionalLess(); |
744 | |
745 | StringAttr include; |
746 | OptionalParseResult includeParseResult = |
747 | parser.parseOptionalAttribute(include, "include" , result.attributes); |
748 | if (!includeParseResult.has_value()) |
749 | return parser.emitError(parser.getNameLoc()) << "expected string attribute" ; |
750 | |
751 | if (standardInclude && parser.parseOptionalGreater()) |
752 | return parser.emitError(parser.getNameLoc()) |
753 | << "expected trailing '>' for standard include" ; |
754 | |
755 | if (standardInclude) |
756 | result.addAttribute("is_standard_include" , |
757 | UnitAttr::get(parser.getContext())); |
758 | |
759 | return success(); |
760 | } |
761 | |
762 | //===----------------------------------------------------------------------===// |
763 | // LiteralOp |
764 | //===----------------------------------------------------------------------===// |
765 | |
766 | /// The literal op requires a non-empty value. |
767 | LogicalResult emitc::LiteralOp::verify() { |
768 | if (getValue().empty()) |
769 | return emitOpError() << "value must not be empty" ; |
770 | return success(); |
771 | } |
772 | //===----------------------------------------------------------------------===// |
773 | // SubOp |
774 | //===----------------------------------------------------------------------===// |
775 | |
776 | LogicalResult SubOp::verify() { |
777 | Type lhsType = getLhs().getType(); |
778 | Type rhsType = getRhs().getType(); |
779 | Type resultType = getResult().getType(); |
780 | |
781 | if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType)) |
782 | return emitOpError("rhs can only be a pointer if lhs is a pointer" ); |
783 | |
784 | if (isa<emitc::PointerType>(lhsType) && |
785 | !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType)) |
786 | return emitOpError("requires that rhs is an integer, pointer or of opaque " |
787 | "type if lhs is a pointer" ); |
788 | |
789 | if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) && |
790 | !isa<IntegerType, emitc::OpaqueType>(resultType)) |
791 | return emitOpError("requires that the result is an integer or of opaque " |
792 | "type if lhs and rhs are pointers" ); |
793 | return success(); |
794 | } |
795 | |
796 | //===----------------------------------------------------------------------===// |
797 | // VariableOp |
798 | //===----------------------------------------------------------------------===// |
799 | |
800 | LogicalResult emitc::VariableOp::verify() { |
801 | return verifyInitializationAttribute(getOperation(), getValueAttr()); |
802 | } |
803 | |
804 | //===----------------------------------------------------------------------===// |
805 | // YieldOp |
806 | //===----------------------------------------------------------------------===// |
807 | |
808 | LogicalResult emitc::YieldOp::verify() { |
809 | Value result = getResult(); |
810 | Operation *containingOp = getOperation()->getParentOp(); |
811 | |
812 | if (result && containingOp->getNumResults() != 1) |
813 | return emitOpError() << "yields a value not returned by parent" ; |
814 | |
815 | if (!result && containingOp->getNumResults() != 0) |
816 | return emitOpError() << "does not yield a value to be returned by parent" ; |
817 | |
818 | return success(); |
819 | } |
820 | |
821 | //===----------------------------------------------------------------------===// |
822 | // SubscriptOp |
823 | //===----------------------------------------------------------------------===// |
824 | |
825 | LogicalResult emitc::SubscriptOp::verify() { |
826 | // Checks for array operand. |
827 | if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) { |
828 | // Check number of indices. |
829 | if (getIndices().size() != (size_t)arrayType.getRank()) { |
830 | return emitOpError() << "on array operand requires number of indices (" |
831 | << getIndices().size() |
832 | << ") to match the rank of the array type (" |
833 | << arrayType.getRank() << ")" ; |
834 | } |
835 | // Check types of index operands. |
836 | for (unsigned i = 0, e = getIndices().size(); i != e; ++i) { |
837 | Type type = getIndices()[i].getType(); |
838 | if (!isIntegerIndexOrOpaqueType(type)) { |
839 | return emitOpError() << "on array operand requires index operand " << i |
840 | << " to be integer-like, but got " << type; |
841 | } |
842 | } |
843 | // Check element type. |
844 | Type elementType = arrayType.getElementType(); |
845 | if (elementType != getType()) { |
846 | return emitOpError() << "on array operand requires element type (" |
847 | << elementType << ") and result type (" << getType() |
848 | << ") to match" ; |
849 | } |
850 | return success(); |
851 | } |
852 | |
853 | // Checks for pointer operand. |
854 | if (auto pointerType = |
855 | llvm::dyn_cast<emitc::PointerType>(getValue().getType())) { |
856 | // Check number of indices. |
857 | if (getIndices().size() != 1) { |
858 | return emitOpError() |
859 | << "on pointer operand requires one index operand, but got " |
860 | << getIndices().size(); |
861 | } |
862 | // Check types of index operand. |
863 | Type type = getIndices()[0].getType(); |
864 | if (!isIntegerIndexOrOpaqueType(type)) { |
865 | return emitOpError() << "on pointer operand requires index operand to be " |
866 | "integer-like, but got " |
867 | << type; |
868 | } |
869 | // Check pointee type. |
870 | Type pointeeType = pointerType.getPointee(); |
871 | if (pointeeType != getType()) { |
872 | return emitOpError() << "on pointer operand requires pointee type (" |
873 | << pointeeType << ") and result type (" << getType() |
874 | << ") to match" ; |
875 | } |
876 | return success(); |
877 | } |
878 | |
879 | // The operand has opaque type, so we can't assume anything about the number |
880 | // or types of index operands. |
881 | return success(); |
882 | } |
883 | |
884 | //===----------------------------------------------------------------------===// |
885 | // EmitC Enums |
886 | //===----------------------------------------------------------------------===// |
887 | |
888 | #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc" |
889 | |
890 | //===----------------------------------------------------------------------===// |
891 | // EmitC Attributes |
892 | //===----------------------------------------------------------------------===// |
893 | |
894 | #define GET_ATTRDEF_CLASSES |
895 | #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" |
896 | |
897 | //===----------------------------------------------------------------------===// |
898 | // EmitC Types |
899 | //===----------------------------------------------------------------------===// |
900 | |
901 | #define GET_TYPEDEF_CLASSES |
902 | #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" |
903 | |
904 | //===----------------------------------------------------------------------===// |
905 | // ArrayType |
906 | //===----------------------------------------------------------------------===// |
907 | |
908 | Type emitc::ArrayType::parse(AsmParser &parser) { |
909 | if (parser.parseLess()) |
910 | return Type(); |
911 | |
912 | SmallVector<int64_t, 4> dimensions; |
913 | if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false, |
914 | /*withTrailingX=*/true)) |
915 | return Type(); |
916 | // Parse the element type. |
917 | auto typeLoc = parser.getCurrentLocation(); |
918 | Type elementType; |
919 | if (parser.parseType(elementType)) |
920 | return Type(); |
921 | |
922 | // Check that array is formed from allowed types. |
923 | if (!isValidElementType(elementType)) |
924 | return parser.emitError(typeLoc, "invalid array element type" ), Type(); |
925 | if (parser.parseGreater()) |
926 | return Type(); |
927 | return parser.getChecked<ArrayType>(dimensions, elementType); |
928 | } |
929 | |
930 | void emitc::ArrayType::print(AsmPrinter &printer) const { |
931 | printer << "<" ; |
932 | for (int64_t dim : getShape()) { |
933 | printer << dim << 'x'; |
934 | } |
935 | printer.printType(getElementType()); |
936 | printer << ">" ; |
937 | } |
938 | |
939 | LogicalResult emitc::ArrayType::verify( |
940 | ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, |
941 | ::llvm::ArrayRef<int64_t> shape, Type elementType) { |
942 | if (shape.empty()) |
943 | return emitError() << "shape must not be empty" ; |
944 | |
945 | for (int64_t dim : shape) { |
946 | if (dim <= 0) |
947 | return emitError() << "dimensions must have positive size" ; |
948 | } |
949 | |
950 | if (!elementType) |
951 | return emitError() << "element type must not be none" ; |
952 | |
953 | if (!isValidElementType(elementType)) |
954 | return emitError() << "invalid array element type" ; |
955 | |
956 | return success(); |
957 | } |
958 | |
959 | emitc::ArrayType |
960 | emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape, |
961 | Type elementType) const { |
962 | if (!shape) |
963 | return emitc::ArrayType::get(getShape(), elementType); |
964 | return emitc::ArrayType::get(*shape, elementType); |
965 | } |
966 | |
967 | //===----------------------------------------------------------------------===// |
968 | // OpaqueType |
969 | //===----------------------------------------------------------------------===// |
970 | |
971 | LogicalResult mlir::emitc::OpaqueType::verify( |
972 | llvm::function_ref<mlir::InFlightDiagnostic()> emitError, |
973 | llvm::StringRef value) { |
974 | if (value.empty()) { |
975 | return emitError() << "expected non empty string in !emitc.opaque type" ; |
976 | } |
977 | if (value.back() == '*') { |
978 | return emitError() << "pointer not allowed as outer type with " |
979 | "!emitc.opaque, use !emitc.ptr instead" ; |
980 | } |
981 | return success(); |
982 | } |
983 | |
984 | //===----------------------------------------------------------------------===// |
985 | // GlobalOp |
986 | //===----------------------------------------------------------------------===// |
987 | static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, |
988 | TypeAttr type, |
989 | Attribute initialValue) { |
990 | p << type; |
991 | if (initialValue) { |
992 | p << " = " ; |
993 | p.printAttributeWithoutType(attr: initialValue); |
994 | } |
995 | } |
996 | |
997 | static Type getInitializerTypeForGlobal(Type type) { |
998 | if (auto array = llvm::dyn_cast<ArrayType>(type)) |
999 | return RankedTensorType::get(array.getShape(), array.getElementType()); |
1000 | return type; |
1001 | } |
1002 | |
1003 | static ParseResult |
1004 | parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, |
1005 | Attribute &initialValue) { |
1006 | Type type; |
1007 | if (parser.parseType(result&: type)) |
1008 | return failure(); |
1009 | |
1010 | typeAttr = TypeAttr::get(type); |
1011 | |
1012 | if (parser.parseOptionalEqual()) |
1013 | return success(); |
1014 | |
1015 | if (parser.parseAttribute(result&: initialValue, type: getInitializerTypeForGlobal(type))) |
1016 | return failure(); |
1017 | |
1018 | if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>( |
1019 | initialValue)) |
1020 | return parser.emitError(loc: parser.getNameLoc()) |
1021 | << "initial value should be a integer, float, elements or opaque " |
1022 | "attribute" ; |
1023 | return success(); |
1024 | } |
1025 | |
1026 | LogicalResult GlobalOp::verify() { |
1027 | if (!isSupportedEmitCType(getType())) { |
1028 | return emitOpError("expected valid emitc type" ); |
1029 | } |
1030 | if (getInitialValue().has_value()) { |
1031 | Attribute initValue = getInitialValue().value(); |
1032 | // Check that the type of the initial value is compatible with the type of |
1033 | // the global variable. |
1034 | if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) { |
1035 | auto arrayType = llvm::dyn_cast<ArrayType>(getType()); |
1036 | if (!arrayType) |
1037 | return emitOpError("expected array type, but got " ) << getType(); |
1038 | |
1039 | Type initType = elementsAttr.getType(); |
1040 | Type tensorType = getInitializerTypeForGlobal(getType()); |
1041 | if (initType != tensorType) { |
1042 | return emitOpError("initial value expected to be of type " ) |
1043 | << getType() << ", but was of type " << initType; |
1044 | } |
1045 | } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) { |
1046 | if (intAttr.getType() != getType()) { |
1047 | return emitOpError("initial value expected to be of type " ) |
1048 | << getType() << ", but was of type " << intAttr.getType(); |
1049 | } |
1050 | } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) { |
1051 | if (floatAttr.getType() != getType()) { |
1052 | return emitOpError("initial value expected to be of type " ) |
1053 | << getType() << ", but was of type " << floatAttr.getType(); |
1054 | } |
1055 | } else if (!isa<emitc::OpaqueAttr>(initValue)) { |
1056 | return emitOpError("initial value should be a integer, float, elements " |
1057 | "or opaque attribute, but got " ) |
1058 | << initValue; |
1059 | } |
1060 | } |
1061 | if (getStaticSpecifier() && getExternSpecifier()) { |
1062 | return emitOpError("cannot have both static and extern specifiers" ); |
1063 | } |
1064 | return success(); |
1065 | } |
1066 | |
1067 | //===----------------------------------------------------------------------===// |
1068 | // GetGlobalOp |
1069 | //===----------------------------------------------------------------------===// |
1070 | |
1071 | LogicalResult |
1072 | GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1073 | // Verify that the type matches the type of the global variable. |
1074 | auto global = |
1075 | symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr()); |
1076 | if (!global) |
1077 | return emitOpError("'" ) |
1078 | << getName() << "' does not reference a valid emitc.global" ; |
1079 | |
1080 | Type resultType = getResult().getType(); |
1081 | if (global.getType() != resultType) |
1082 | return emitOpError("result type " ) |
1083 | << resultType << " does not match type " << global.getType() |
1084 | << " of the global @" << getName(); |
1085 | return success(); |
1086 | } |
1087 | |
1088 | //===----------------------------------------------------------------------===// |
1089 | // TableGen'd op method definitions |
1090 | //===----------------------------------------------------------------------===// |
1091 | |
1092 | #define GET_OP_CLASSES |
1093 | #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" |
1094 | |