//===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
10 | #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/BuiltinAttributes.h" |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/DialectImplementation.h" |
15 | #include "mlir/IR/IRMapping.h" |
16 | #include "mlir/IR/Types.h" |
17 | #include "mlir/Interfaces/FunctionImplementation.h" |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/ADT/StringExtras.h" |
20 | #include "llvm/ADT/TypeSwitch.h" |
21 | #include "llvm/Support/Casting.h" |
22 | #include "llvm/Support/FormatVariadic.h" |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::emitc; |
26 | |
27 | #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc" |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // EmitCDialect |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | void EmitCDialect::initialize() { |
34 | addOperations< |
35 | #define GET_OP_LIST |
36 | #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" |
37 | >(); |
38 | addTypes< |
39 | #define GET_TYPEDEF_LIST |
40 | #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" |
41 | >(); |
42 | addAttributes< |
43 | #define GET_ATTRDEF_LIST |
44 | #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" |
45 | >(); |
46 | } |
47 | |
48 | /// Materialize a single constant operation from a given attribute value with |
49 | /// the desired resultant type. |
50 | Operation *EmitCDialect::materializeConstant(OpBuilder &builder, |
51 | Attribute value, Type type, |
52 | Location loc) { |
53 | return builder.create<emitc::ConstantOp>(loc, type, value); |
54 | } |
55 | |
56 | /// Default callback for builders of ops carrying a region. Inserts a yield |
57 | /// without arguments. |
58 | void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { |
59 | builder.create<emitc::YieldOp>(loc); |
60 | } |
61 | |
62 | bool mlir::emitc::isSupportedEmitCType(Type type) { |
63 | if (llvm::isa<emitc::OpaqueType>(type)) |
64 | return true; |
65 | if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type)) |
66 | return isSupportedEmitCType(ptrType.getPointee()); |
67 | if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) { |
68 | auto elemType = arrayType.getElementType(); |
69 | return !llvm::isa<emitc::ArrayType>(elemType) && |
70 | isSupportedEmitCType(elemType); |
71 | } |
72 | if (type.isIndex() || emitc::isPointerWideType(type)) |
73 | return true; |
74 | if (llvm::isa<IntegerType>(Val: type)) |
75 | return isSupportedIntegerType(type); |
76 | if (llvm::isa<FloatType>(Val: type)) |
77 | return isSupportedFloatType(type); |
78 | if (auto tensorType = llvm::dyn_cast<TensorType>(Val&: type)) { |
79 | if (!tensorType.hasStaticShape()) { |
80 | return false; |
81 | } |
82 | auto elemType = tensorType.getElementType(); |
83 | if (llvm::isa<emitc::ArrayType>(elemType)) { |
84 | return false; |
85 | } |
86 | return isSupportedEmitCType(type: elemType); |
87 | } |
88 | if (auto tupleType = llvm::dyn_cast<TupleType>(type)) { |
89 | return llvm::all_of(tupleType.getTypes(), [](Type type) { |
90 | return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type); |
91 | }); |
92 | } |
93 | return false; |
94 | } |
95 | |
96 | bool mlir::emitc::isSupportedIntegerType(Type type) { |
97 | if (auto intType = llvm::dyn_cast<IntegerType>(type)) { |
98 | switch (intType.getWidth()) { |
99 | case 1: |
100 | case 8: |
101 | case 16: |
102 | case 32: |
103 | case 64: |
104 | return true; |
105 | default: |
106 | return false; |
107 | } |
108 | } |
109 | return false; |
110 | } |
111 | |
112 | bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) { |
113 | return llvm::isa<IndexType, emitc::OpaqueType>(type) || |
114 | isSupportedIntegerType(type) || isPointerWideType(type); |
115 | } |
116 | |
117 | bool mlir::emitc::isSupportedFloatType(Type type) { |
118 | if (auto floatType = llvm::dyn_cast<FloatType>(type)) { |
119 | switch (floatType.getWidth()) { |
120 | case 16: { |
121 | if (llvm::isa<Float16Type, BFloat16Type>(type)) |
122 | return true; |
123 | return false; |
124 | } |
125 | case 32: |
126 | case 64: |
127 | return true; |
128 | default: |
129 | return false; |
130 | } |
131 | } |
132 | return false; |
133 | } |
134 | |
135 | bool mlir::emitc::isPointerWideType(Type type) { |
136 | return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>( |
137 | type); |
138 | } |
139 | |
140 | /// Check that the type of the initial value is compatible with the operations |
141 | /// result type. |
142 | static LogicalResult verifyInitializationAttribute(Operation *op, |
143 | Attribute value) { |
144 | assert(op->getNumResults() == 1 && "operation must have 1 result"); |
145 | |
146 | if (llvm::isa<emitc::OpaqueAttr>(value)) |
147 | return success(); |
148 | |
149 | if (llvm::isa<StringAttr>(Val: value)) |
150 | return op->emitOpError() |
151 | << "string attributes are not supported, use #emitc.opaque instead"; |
152 | |
153 | Type resultType = op->getResult(idx: 0).getType(); |
154 | if (auto lType = dyn_cast<LValueType>(resultType)) |
155 | resultType = lType.getValueType(); |
156 | Type attrType = cast<TypedAttr>(value).getType(); |
157 | |
158 | if (isPointerWideType(type: resultType) && attrType.isIndex()) |
159 | return success(); |
160 | |
161 | if (resultType != attrType) |
162 | return op->emitOpError() |
163 | << "requires attribute to either be an #emitc.opaque attribute or " |
164 | "it's type (" |
165 | << attrType << ") to match the op's result type ("<< resultType |
166 | << ")"; |
167 | |
168 | return success(); |
169 | } |
170 | |
171 | /// Parse a format string and return a list of its parts. |
172 | /// A part is either a StringRef that has to be printed as-is, or |
173 | /// a Placeholder which requires printing the next operand of the VerbatimOp. |
174 | /// In the format string, all `{}` are replaced by Placeholders, except if the |
175 | /// `{` is escaped by `{{` - then it doesn't start a placeholder. |
176 | template <class ArgType> |
177 | FailureOr<SmallVector<ReplacementItem>> |
178 | parseFormatString(StringRef toParse, ArgType fmtArgs, |
179 | std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>> |
180 | emitError = {}) { |
181 | SmallVector<ReplacementItem> items; |
182 | |
183 | // If there are not operands, the format string is not interpreted. |
184 | if (fmtArgs.empty()) { |
185 | items.push_back(Elt: toParse); |
186 | return items; |
187 | } |
188 | |
189 | while (!toParse.empty()) { |
190 | size_t idx = toParse.find(C: '{'); |
191 | if (idx == StringRef::npos) { |
192 | // No '{' |
193 | items.push_back(Elt: toParse); |
194 | break; |
195 | } |
196 | if (idx > 0) { |
197 | // Take all chars excluding the '{'. |
198 | items.push_back(Elt: toParse.take_front(N: idx)); |
199 | toParse = toParse.drop_front(N: idx); |
200 | continue; |
201 | } |
202 | if (toParse.size() < 2) { |
203 | return (*emitError)() |
204 | << "expected '}' after unescaped '{' at end of string"; |
205 | } |
206 | // toParse contains at least two characters and starts with `{`. |
207 | char nextChar = toParse[1]; |
208 | if (nextChar == '{') { |
209 | // Double '{{' -> '{' (escaping). |
210 | items.push_back(Elt: toParse.take_front(N: 1)); |
211 | toParse = toParse.drop_front(N: 2); |
212 | continue; |
213 | } |
214 | if (nextChar == '}') { |
215 | items.push_back(Elt: Placeholder{}); |
216 | toParse = toParse.drop_front(N: 2); |
217 | continue; |
218 | } |
219 | |
220 | if (emitError.has_value()) { |
221 | return (*emitError)() << "expected '}' after unescaped '{'"; |
222 | } |
223 | return failure(); |
224 | } |
225 | return items; |
226 | } |
227 | |
228 | //===----------------------------------------------------------------------===// |
229 | // AddOp |
230 | //===----------------------------------------------------------------------===// |
231 | |
232 | LogicalResult AddOp::verify() { |
233 | Type lhsType = getLhs().getType(); |
234 | Type rhsType = getRhs().getType(); |
235 | |
236 | if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType)) |
237 | return emitOpError("requires that at most one operand is a pointer"); |
238 | |
239 | if ((isa<emitc::PointerType>(lhsType) && |
240 | !isa<IntegerType, emitc::OpaqueType>(rhsType)) || |
241 | (isa<emitc::PointerType>(rhsType) && |
242 | !isa<IntegerType, emitc::OpaqueType>(lhsType))) |
243 | return emitOpError("requires that one operand is an integer or of opaque " |
244 | "type if the other is a pointer"); |
245 | |
246 | return success(); |
247 | } |
248 | |
249 | //===----------------------------------------------------------------------===// |
250 | // ApplyOp |
251 | //===----------------------------------------------------------------------===// |
252 | |
253 | LogicalResult ApplyOp::verify() { |
254 | StringRef applicableOperatorStr = getApplicableOperator(); |
255 | |
256 | // Applicable operator must not be empty. |
257 | if (applicableOperatorStr.empty()) |
258 | return emitOpError("applicable operator must not be empty"); |
259 | |
260 | // Only `*` and `&` are supported. |
261 | if (applicableOperatorStr != "&"&& applicableOperatorStr != "*") |
262 | return emitOpError("applicable operator is illegal"); |
263 | |
264 | Type operandType = getOperand().getType(); |
265 | Type resultType = getResult().getType(); |
266 | if (applicableOperatorStr == "&") { |
267 | if (!llvm::isa<emitc::LValueType>(operandType)) |
268 | return emitOpError("operand type must be an lvalue when applying `&`"); |
269 | if (!llvm::isa<emitc::PointerType>(resultType)) |
270 | return emitOpError("result type must be a pointer when applying `&`"); |
271 | } else { |
272 | if (!llvm::isa<emitc::PointerType>(operandType)) |
273 | return emitOpError("operand type must be a pointer when applying `*`"); |
274 | } |
275 | |
276 | return success(); |
277 | } |
278 | |
279 | //===----------------------------------------------------------------------===// |
280 | // AssignOp |
281 | //===----------------------------------------------------------------------===// |
282 | |
283 | /// The assign op requires that the assigned value's type matches the |
284 | /// assigned-to variable type. |
285 | LogicalResult emitc::AssignOp::verify() { |
286 | TypedValue<emitc::LValueType> variable = getVar(); |
287 | |
288 | if (!variable.getDefiningOp()) |
289 | return emitOpError() << "cannot assign to block argument"; |
290 | |
291 | Type valueType = getValue().getType(); |
292 | Type variableType = variable.getType().getValueType(); |
293 | if (variableType != valueType) |
294 | return emitOpError() << "requires value's type ("<< valueType |
295 | << ") to match variable's type ("<< variableType |
296 | << ")\n variable: "<< variable |
297 | << "\n value: "<< getValue() << "\n"; |
298 | return success(); |
299 | } |
300 | |
301 | //===----------------------------------------------------------------------===// |
302 | // CastOp |
303 | //===----------------------------------------------------------------------===// |
304 | |
305 | bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
306 | Type input = inputs.front(), output = outputs.front(); |
307 | |
308 | if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) { |
309 | if (auto pointerType = dyn_cast<emitc::PointerType>(output)) { |
310 | return (arrayType.getElementType() == pointerType.getPointee()) && |
311 | arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1; |
312 | } |
313 | return false; |
314 | } |
315 | |
316 | return ( |
317 | (emitc::isIntegerIndexOrOpaqueType(input) || |
318 | emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) && |
319 | (emitc::isIntegerIndexOrOpaqueType(output) || |
320 | emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output))); |
321 | } |
322 | |
323 | //===----------------------------------------------------------------------===// |
324 | // CallOpaqueOp |
325 | //===----------------------------------------------------------------------===// |
326 | |
327 | LogicalResult emitc::CallOpaqueOp::verify() { |
328 | // Callee must not be empty. |
329 | if (getCallee().empty()) |
330 | return emitOpError("callee must not be empty"); |
331 | |
332 | if (std::optional<ArrayAttr> argsAttr = getArgs()) { |
333 | for (Attribute arg : *argsAttr) { |
334 | auto intAttr = llvm::dyn_cast<IntegerAttr>(arg); |
335 | if (intAttr && llvm::isa<IndexType>(intAttr.getType())) { |
336 | int64_t index = intAttr.getInt(); |
337 | // Args with elements of type index must be in range |
338 | // [0..operands.size). |
339 | if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands()))) |
340 | return emitOpError("index argument is out of range"); |
341 | |
342 | // Args with elements of type ArrayAttr must have a type. |
343 | } else if (llvm::isa<ArrayAttr>( |
344 | arg) /*&& llvm::isa<NoneType>(arg.getType())*/) { |
345 | // FIXME: Array attributes never have types |
346 | return emitOpError("array argument has no type"); |
347 | } |
348 | } |
349 | } |
350 | |
351 | if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) { |
352 | for (Attribute tArg : *templateArgsAttr) { |
353 | if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg)) |
354 | return emitOpError("template argument has invalid type"); |
355 | } |
356 | } |
357 | |
358 | if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) { |
359 | return emitOpError() << "cannot return array type"; |
360 | } |
361 | |
362 | return success(); |
363 | } |
364 | |
365 | //===----------------------------------------------------------------------===// |
366 | // ConstantOp |
367 | //===----------------------------------------------------------------------===// |
368 | |
369 | LogicalResult emitc::ConstantOp::verify() { |
370 | Attribute value = getValueAttr(); |
371 | if (failed(verifyInitializationAttribute(getOperation(), value))) |
372 | return failure(); |
373 | if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) { |
374 | if (opaqueValue.getValue().empty()) |
375 | return emitOpError() << "value must not be empty"; |
376 | } |
377 | return success(); |
378 | } |
379 | |
380 | OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } |
381 | |
382 | //===----------------------------------------------------------------------===// |
383 | // ExpressionOp |
384 | //===----------------------------------------------------------------------===// |
385 | |
386 | Operation *ExpressionOp::getRootOp() { |
387 | auto yieldOp = cast<YieldOp>(getBody()->getTerminator()); |
388 | Value yieldedValue = yieldOp.getResult(); |
389 | Operation *rootOp = yieldedValue.getDefiningOp(); |
390 | assert(rootOp && "Yielded value not defined within expression"); |
391 | return rootOp; |
392 | } |
393 | |
394 | LogicalResult ExpressionOp::verify() { |
395 | Type resultType = getResult().getType(); |
396 | Region ®ion = getRegion(); |
397 | |
398 | Block &body = region.front(); |
399 | |
400 | if (!body.mightHaveTerminator()) |
401 | return emitOpError("must yield a value at termination"); |
402 | |
403 | auto yield = cast<YieldOp>(body.getTerminator()); |
404 | Value yieldResult = yield.getResult(); |
405 | |
406 | if (!yieldResult) |
407 | return emitOpError("must yield a value at termination"); |
408 | |
409 | Type yieldType = yieldResult.getType(); |
410 | |
411 | if (resultType != yieldType) |
412 | return emitOpError("requires yielded type to match return type"); |
413 | |
414 | for (Operation &op : region.front().without_terminator()) { |
415 | if (!op.hasTrait<OpTrait::emitc::CExpression>()) |
416 | return emitOpError("contains an unsupported operation"); |
417 | if (op.getNumResults() != 1) |
418 | return emitOpError("requires exactly one result for each operation"); |
419 | if (!op.getResult(0).hasOneUse()) |
420 | return emitOpError("requires exactly one use for each operation"); |
421 | } |
422 | |
423 | return success(); |
424 | } |
425 | |
426 | //===----------------------------------------------------------------------===// |
427 | // ForOp |
428 | //===----------------------------------------------------------------------===// |
429 | |
430 | void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, |
431 | Value ub, Value step, BodyBuilderFn bodyBuilder) { |
432 | OpBuilder::InsertionGuard g(builder); |
433 | result.addOperands({lb, ub, step}); |
434 | Type t = lb.getType(); |
435 | Region *bodyRegion = result.addRegion(); |
436 | Block *bodyBlock = builder.createBlock(bodyRegion); |
437 | bodyBlock->addArgument(t, result.location); |
438 | |
439 | // Create the default terminator if the builder is not provided. |
440 | if (!bodyBuilder) { |
441 | ForOp::ensureTerminator(*bodyRegion, builder, result.location); |
442 | } else { |
443 | OpBuilder::InsertionGuard guard(builder); |
444 | builder.setInsertionPointToStart(bodyBlock); |
445 | bodyBuilder(builder, result.location, bodyBlock->getArgument(0)); |
446 | } |
447 | } |
448 | |
449 | void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {} |
450 | |
451 | ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { |
452 | Builder &builder = parser.getBuilder(); |
453 | Type type; |
454 | |
455 | OpAsmParser::Argument inductionVariable; |
456 | OpAsmParser::UnresolvedOperand lb, ub, step; |
457 | |
458 | // Parse the induction variable followed by '='. |
459 | if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || |
460 | // Parse loop bounds. |
461 | parser.parseOperand(lb) || parser.parseKeyword("to") || |
462 | parser.parseOperand(ub) || parser.parseKeyword("step") || |
463 | parser.parseOperand(step)) |
464 | return failure(); |
465 | |
466 | // Parse the optional initial iteration arguments. |
467 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
468 | regionArgs.push_back(inductionVariable); |
469 | |
470 | // Parse optional type, else assume Index. |
471 | if (parser.parseOptionalColon()) |
472 | type = builder.getIndexType(); |
473 | else if (parser.parseType(type)) |
474 | return failure(); |
475 | |
476 | // Resolve input operands. |
477 | regionArgs.front().type = type; |
478 | if (parser.resolveOperand(lb, type, result.operands) || |
479 | parser.resolveOperand(ub, type, result.operands) || |
480 | parser.resolveOperand(step, type, result.operands)) |
481 | return failure(); |
482 | |
483 | // Parse the body region. |
484 | Region *body = result.addRegion(); |
485 | if (parser.parseRegion(*body, regionArgs)) |
486 | return failure(); |
487 | |
488 | ForOp::ensureTerminator(*body, builder, result.location); |
489 | |
490 | // Parse the optional attribute list. |
491 | if (parser.parseOptionalAttrDict(result.attributes)) |
492 | return failure(); |
493 | |
494 | return success(); |
495 | } |
496 | |
497 | void ForOp::print(OpAsmPrinter &p) { |
498 | p << " "<< getInductionVar() << " = "<< getLowerBound() << " to " |
499 | << getUpperBound() << " step "<< getStep(); |
500 | |
501 | p << ' '; |
502 | if (Type t = getInductionVar().getType(); !t.isIndex()) |
503 | p << " : "<< t << ' '; |
504 | p.printRegion(getRegion(), |
505 | /*printEntryBlockArgs=*/false, |
506 | /*printBlockTerminators=*/false); |
507 | p.printOptionalAttrDict((*this)->getAttrs()); |
508 | } |
509 | |
510 | LogicalResult ForOp::verifyRegions() { |
511 | // Check that the body defines as single block argument for the induction |
512 | // variable. |
513 | if (getInductionVar().getType() != getLowerBound().getType()) |
514 | return emitOpError( |
515 | "expected induction variable to be same type as bounds and step"); |
516 | |
517 | return success(); |
518 | } |
519 | |
520 | //===----------------------------------------------------------------------===// |
521 | // CallOp |
522 | //===----------------------------------------------------------------------===// |
523 | |
524 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
525 | // Check that the callee attribute was specified. |
526 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); |
527 | if (!fnAttr) |
528 | return emitOpError("requires a 'callee' symbol reference attribute"); |
529 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
530 | if (!fn) |
531 | return emitOpError() << "'"<< fnAttr.getValue() |
532 | << "' does not reference a valid function"; |
533 | |
534 | // Verify that the operand and result types match the callee. |
535 | auto fnType = fn.getFunctionType(); |
536 | if (fnType.getNumInputs() != getNumOperands()) |
537 | return emitOpError("incorrect number of operands for callee"); |
538 | |
539 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
540 | if (getOperand(i).getType() != fnType.getInput(i)) |
541 | return emitOpError("operand type mismatch: expected operand type ") |
542 | << fnType.getInput(i) << ", but provided " |
543 | << getOperand(i).getType() << " for operand number "<< i; |
544 | |
545 | if (fnType.getNumResults() != getNumResults()) |
546 | return emitOpError("incorrect number of results for callee"); |
547 | |
548 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
549 | if (getResult(i).getType() != fnType.getResult(i)) { |
550 | auto diag = emitOpError("result type mismatch at index ") << i; |
551 | diag.attachNote() << " op result types: "<< getResultTypes(); |
552 | diag.attachNote() << "function result types: "<< fnType.getResults(); |
553 | return diag; |
554 | } |
555 | |
556 | return success(); |
557 | } |
558 | |
559 | FunctionType CallOp::getCalleeType() { |
560 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
561 | } |
562 | |
563 | //===----------------------------------------------------------------------===// |
564 | // DeclareFuncOp |
565 | //===----------------------------------------------------------------------===// |
566 | |
567 | LogicalResult |
568 | DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
569 | // Check that the sym_name attribute was specified. |
570 | auto fnAttr = getSymNameAttr(); |
571 | if (!fnAttr) |
572 | return emitOpError("requires a 'sym_name' symbol reference attribute"); |
573 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
574 | if (!fn) |
575 | return emitOpError() << "'"<< fnAttr.getValue() |
576 | << "' does not reference a valid function"; |
577 | |
578 | return success(); |
579 | } |
580 | |
581 | //===----------------------------------------------------------------------===// |
582 | // FuncOp |
583 | //===----------------------------------------------------------------------===// |
584 | |
585 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
586 | FunctionType type, ArrayRef<NamedAttribute> attrs, |
587 | ArrayRef<DictionaryAttr> argAttrs) { |
588 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
589 | builder.getStringAttr(name)); |
590 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
591 | state.attributes.append(attrs.begin(), attrs.end()); |
592 | state.addRegion(); |
593 | |
594 | if (argAttrs.empty()) |
595 | return; |
596 | assert(type.getNumInputs() == argAttrs.size()); |
597 | call_interface_impl::addArgAndResultAttrs( |
598 | builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
599 | getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
600 | } |
601 | |
602 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
603 | auto buildFuncType = |
604 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
605 | function_interface_impl::VariadicFlag, |
606 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
607 | |
608 | return function_interface_impl::parseFunctionOp( |
609 | parser, result, /*allowVariadic=*/false, |
610 | getFunctionTypeAttrName(result.name), buildFuncType, |
611 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
612 | } |
613 | |
614 | void FuncOp::print(OpAsmPrinter &p) { |
615 | function_interface_impl::printFunctionOp( |
616 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
617 | getArgAttrsAttrName(), getResAttrsAttrName()); |
618 | } |
619 | |
620 | LogicalResult FuncOp::verify() { |
621 | if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) { |
622 | return emitOpError("cannot have lvalue type as argument"); |
623 | } |
624 | |
625 | if (getNumResults() > 1) |
626 | return emitOpError("requires zero or exactly one result, but has ") |
627 | << getNumResults(); |
628 | |
629 | if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0])) |
630 | return emitOpError("cannot return array type"); |
631 | |
632 | return success(); |
633 | } |
634 | |
635 | //===----------------------------------------------------------------------===// |
636 | // ReturnOp |
637 | //===----------------------------------------------------------------------===// |
638 | |
639 | LogicalResult ReturnOp::verify() { |
640 | auto function = cast<FuncOp>((*this)->getParentOp()); |
641 | |
642 | // The operand number and types must match the function signature. |
643 | if (getNumOperands() != function.getNumResults()) |
644 | return emitOpError("has ") |
645 | << getNumOperands() << " operands, but enclosing function (@" |
646 | << function.getName() << ") returns "<< function.getNumResults(); |
647 | |
648 | if (function.getNumResults() == 1) |
649 | if (getOperand().getType() != function.getResultTypes()[0]) |
650 | return emitError() << "type of the return operand (" |
651 | << getOperand().getType() |
652 | << ") doesn't match function result type (" |
653 | << function.getResultTypes()[0] << ")" |
654 | << " in function @"<< function.getName(); |
655 | return success(); |
656 | } |
657 | |
658 | //===----------------------------------------------------------------------===// |
659 | // IfOp |
660 | //===----------------------------------------------------------------------===// |
661 | |
662 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
663 | bool addThenBlock, bool addElseBlock) { |
664 | assert((!addElseBlock || addThenBlock) && |
665 | "must not create else block w/o then block"); |
666 | result.addOperands(cond); |
667 | |
668 | // Add regions and blocks. |
669 | OpBuilder::InsertionGuard guard(builder); |
670 | Region *thenRegion = result.addRegion(); |
671 | if (addThenBlock) |
672 | builder.createBlock(thenRegion); |
673 | Region *elseRegion = result.addRegion(); |
674 | if (addElseBlock) |
675 | builder.createBlock(elseRegion); |
676 | } |
677 | |
678 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
679 | bool withElseRegion) { |
680 | result.addOperands(cond); |
681 | |
682 | // Build then region. |
683 | OpBuilder::InsertionGuard guard(builder); |
684 | Region *thenRegion = result.addRegion(); |
685 | builder.createBlock(thenRegion); |
686 | |
687 | // Build else region. |
688 | Region *elseRegion = result.addRegion(); |
689 | if (withElseRegion) { |
690 | builder.createBlock(elseRegion); |
691 | } |
692 | } |
693 | |
694 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
695 | function_ref<void(OpBuilder &, Location)> thenBuilder, |
696 | function_ref<void(OpBuilder &, Location)> elseBuilder) { |
697 | assert(thenBuilder && "the builder callback for 'then' must be present"); |
698 | result.addOperands(cond); |
699 | |
700 | // Build then region. |
701 | OpBuilder::InsertionGuard guard(builder); |
702 | Region *thenRegion = result.addRegion(); |
703 | builder.createBlock(thenRegion); |
704 | thenBuilder(builder, result.location); |
705 | |
706 | // Build else region. |
707 | Region *elseRegion = result.addRegion(); |
708 | if (elseBuilder) { |
709 | builder.createBlock(elseRegion); |
710 | elseBuilder(builder, result.location); |
711 | } |
712 | } |
713 | |
714 | ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { |
715 | // Create the regions for 'then'. |
716 | result.regions.reserve(2); |
717 | Region *thenRegion = result.addRegion(); |
718 | Region *elseRegion = result.addRegion(); |
719 | |
720 | Builder &builder = parser.getBuilder(); |
721 | OpAsmParser::UnresolvedOperand cond; |
722 | Type i1Type = builder.getIntegerType(1); |
723 | if (parser.parseOperand(cond) || |
724 | parser.resolveOperand(cond, i1Type, result.operands)) |
725 | return failure(); |
726 | // Parse the 'then' region. |
727 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
728 | return failure(); |
729 | IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); |
730 | |
731 | // If we find an 'else' keyword then parse the 'else' region. |
732 | if (!parser.parseOptionalKeyword("else")) { |
733 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
734 | return failure(); |
735 | IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); |
736 | } |
737 | |
738 | // Parse the optional attribute list. |
739 | if (parser.parseOptionalAttrDict(result.attributes)) |
740 | return failure(); |
741 | return success(); |
742 | } |
743 | |
744 | void IfOp::print(OpAsmPrinter &p) { |
745 | bool printBlockTerminators = false; |
746 | |
747 | p << " "<< getCondition(); |
748 | p << ' '; |
749 | p.printRegion(getThenRegion(), |
750 | /*printEntryBlockArgs=*/false, |
751 | /*printBlockTerminators=*/printBlockTerminators); |
752 | |
753 | // Print the 'else' regions if it exists and has a block. |
754 | Region &elseRegion = getElseRegion(); |
755 | if (!elseRegion.empty()) { |
756 | p << " else "; |
757 | p.printRegion(elseRegion, |
758 | /*printEntryBlockArgs=*/false, |
759 | /*printBlockTerminators=*/printBlockTerminators); |
760 | } |
761 | |
762 | p.printOptionalAttrDict((*this)->getAttrs()); |
763 | } |
764 | |
765 | /// Given the region at `index`, or the parent operation if `index` is None, |
766 | /// return the successor regions. These are the regions that may be selected |
767 | /// during the flow of control. `operands` is a set of optional attributes |
768 | /// that correspond to a constant value for each operand, or null if that |
769 | /// operand is not a constant. |
770 | void IfOp::getSuccessorRegions(RegionBranchPoint point, |
771 | SmallVectorImpl<RegionSuccessor> ®ions) { |
772 | // The `then` and the `else` region branch back to the parent operation. |
773 | if (!point.isParent()) { |
774 | regions.push_back(RegionSuccessor()); |
775 | return; |
776 | } |
777 | |
778 | regions.push_back(RegionSuccessor(&getThenRegion())); |
779 | |
780 | // Don't consider the else region if it is empty. |
781 | Region *elseRegion = &this->getElseRegion(); |
782 | if (elseRegion->empty()) |
783 | regions.push_back(RegionSuccessor()); |
784 | else |
785 | regions.push_back(RegionSuccessor(elseRegion)); |
786 | } |
787 | |
788 | void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, |
789 | SmallVectorImpl<RegionSuccessor> ®ions) { |
790 | FoldAdaptor adaptor(operands, *this); |
791 | auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); |
792 | if (!boolAttr || boolAttr.getValue()) |
793 | regions.emplace_back(&getThenRegion()); |
794 | |
795 | // If the else region is empty, execution continues after the parent op. |
796 | if (!boolAttr || !boolAttr.getValue()) { |
797 | if (!getElseRegion().empty()) |
798 | regions.emplace_back(&getElseRegion()); |
799 | else |
800 | regions.emplace_back(); |
801 | } |
802 | } |
803 | |
804 | void IfOp::getRegionInvocationBounds( |
805 | ArrayRef<Attribute> operands, |
806 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
807 | if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) { |
808 | // If the condition is known, then one region is known to be executed once |
809 | // and the other zero times. |
810 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
811 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
812 | } else { |
813 | // Non-constant condition. Each region may be executed 0 or 1 times. |
814 | invocationBounds.assign(2, {0, 1}); |
815 | } |
816 | } |
817 | |
818 | //===----------------------------------------------------------------------===// |
819 | // IncludeOp |
820 | //===----------------------------------------------------------------------===// |
821 | |
822 | void IncludeOp::print(OpAsmPrinter &p) { |
823 | bool standardInclude = getIsStandardInclude(); |
824 | |
825 | p << " "; |
826 | if (standardInclude) |
827 | p << "<"; |
828 | p << "\""<< getInclude() << "\""; |
829 | if (standardInclude) |
830 | p << ">"; |
831 | } |
832 | |
833 | ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) { |
834 | bool standardInclude = !parser.parseOptionalLess(); |
835 | |
836 | StringAttr include; |
837 | OptionalParseResult includeParseResult = |
838 | parser.parseOptionalAttribute(include, "include", result.attributes); |
839 | if (!includeParseResult.has_value()) |
840 | return parser.emitError(parser.getNameLoc()) << "expected string attribute"; |
841 | |
842 | if (standardInclude && parser.parseOptionalGreater()) |
843 | return parser.emitError(parser.getNameLoc()) |
844 | << "expected trailing '>' for standard include"; |
845 | |
846 | if (standardInclude) |
847 | result.addAttribute("is_standard_include", |
848 | UnitAttr::get(parser.getContext())); |
849 | |
850 | return success(); |
851 | } |
852 | |
853 | //===----------------------------------------------------------------------===// |
854 | // LiteralOp |
855 | //===----------------------------------------------------------------------===// |
856 | |
857 | /// The literal op requires a non-empty value. |
858 | LogicalResult emitc::LiteralOp::verify() { |
859 | if (getValue().empty()) |
860 | return emitOpError() << "value must not be empty"; |
861 | return success(); |
862 | } |
863 | //===----------------------------------------------------------------------===// |
864 | // SubOp |
865 | //===----------------------------------------------------------------------===// |
866 | |
867 | LogicalResult SubOp::verify() { |
868 | Type lhsType = getLhs().getType(); |
869 | Type rhsType = getRhs().getType(); |
870 | Type resultType = getResult().getType(); |
871 | |
872 | if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType)) |
873 | return emitOpError("rhs can only be a pointer if lhs is a pointer"); |
874 | |
875 | if (isa<emitc::PointerType>(lhsType) && |
876 | !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType)) |
877 | return emitOpError("requires that rhs is an integer, pointer or of opaque " |
878 | "type if lhs is a pointer"); |
879 | |
880 | if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) && |
881 | !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType)) |
882 | return emitOpError("requires that the result is an integer, ptrdiff_t or " |
883 | "of opaque type if lhs and rhs are pointers"); |
884 | return success(); |
885 | } |
886 | |
887 | //===----------------------------------------------------------------------===// |
888 | // VariableOp |
889 | //===----------------------------------------------------------------------===// |
890 | |
891 | LogicalResult emitc::VariableOp::verify() { |
892 | return verifyInitializationAttribute(getOperation(), getValueAttr()); |
893 | } |
894 | |
895 | //===----------------------------------------------------------------------===// |
896 | // YieldOp |
897 | //===----------------------------------------------------------------------===// |
898 | |
899 | LogicalResult emitc::YieldOp::verify() { |
900 | Value result = getResult(); |
901 | Operation *containingOp = getOperation()->getParentOp(); |
902 | |
903 | if (result && containingOp->getNumResults() != 1) |
904 | return emitOpError() << "yields a value not returned by parent"; |
905 | |
906 | if (!result && containingOp->getNumResults() != 0) |
907 | return emitOpError() << "does not yield a value to be returned by parent"; |
908 | |
909 | return success(); |
910 | } |
911 | |
912 | //===----------------------------------------------------------------------===// |
913 | // SubscriptOp |
914 | //===----------------------------------------------------------------------===// |
915 | |
916 | LogicalResult emitc::SubscriptOp::verify() { |
917 | // Checks for array operand. |
918 | if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) { |
919 | // Check number of indices. |
920 | if (getIndices().size() != (size_t)arrayType.getRank()) { |
921 | return emitOpError() << "on array operand requires number of indices (" |
922 | << getIndices().size() |
923 | << ") to match the rank of the array type (" |
924 | << arrayType.getRank() << ")"; |
925 | } |
926 | // Check types of index operands. |
927 | for (unsigned i = 0, e = getIndices().size(); i != e; ++i) { |
928 | Type type = getIndices()[i].getType(); |
929 | if (!isIntegerIndexOrOpaqueType(type)) { |
930 | return emitOpError() << "on array operand requires index operand "<< i |
931 | << " to be integer-like, but got "<< type; |
932 | } |
933 | } |
934 | // Check element type. |
935 | Type elementType = arrayType.getElementType(); |
936 | Type resultType = getType().getValueType(); |
937 | if (elementType != resultType) { |
938 | return emitOpError() << "on array operand requires element type (" |
939 | << elementType << ") and result type ("<< resultType |
940 | << ") to match"; |
941 | } |
942 | return success(); |
943 | } |
944 | |
945 | // Checks for pointer operand. |
946 | if (auto pointerType = |
947 | llvm::dyn_cast<emitc::PointerType>(getValue().getType())) { |
948 | // Check number of indices. |
949 | if (getIndices().size() != 1) { |
950 | return emitOpError() |
951 | << "on pointer operand requires one index operand, but got " |
952 | << getIndices().size(); |
953 | } |
954 | // Check types of index operand. |
955 | Type type = getIndices()[0].getType(); |
956 | if (!isIntegerIndexOrOpaqueType(type)) { |
957 | return emitOpError() << "on pointer operand requires index operand to be " |
958 | "integer-like, but got " |
959 | << type; |
960 | } |
961 | // Check pointee type. |
962 | Type pointeeType = pointerType.getPointee(); |
963 | Type resultType = getType().getValueType(); |
964 | if (pointeeType != resultType) { |
965 | return emitOpError() << "on pointer operand requires pointee type (" |
966 | << pointeeType << ") and result type ("<< resultType |
967 | << ") to match"; |
968 | } |
969 | return success(); |
970 | } |
971 | |
972 | // The operand has opaque type, so we can't assume anything about the number |
973 | // or types of index operands. |
974 | return success(); |
975 | } |
976 | |
977 | //===----------------------------------------------------------------------===// |
978 | // VerbatimOp |
979 | //===----------------------------------------------------------------------===// |
980 | |
981 | LogicalResult emitc::VerbatimOp::verify() { |
982 | auto errorCallback = [&]() -> InFlightDiagnostic { |
983 | return this->emitOpError(); |
984 | }; |
985 | FailureOr<SmallVector<ReplacementItem>> fmt = |
986 | ::parseFormatString(getValue(), getFmtArgs(), errorCallback); |
987 | if (failed(fmt)) |
988 | return failure(); |
989 | |
990 | size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) { |
991 | return std::holds_alternative<Placeholder>(item); |
992 | }); |
993 | |
994 | if (numPlaceholders != getFmtArgs().size()) { |
995 | return emitOpError() |
996 | << "requires operands for each placeholder in the format string"; |
997 | } |
998 | return success(); |
999 | } |
1000 | |
1001 | FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() { |
1002 | // Error checking is done in verify. |
1003 | return ::parseFormatString(getValue(), getFmtArgs()); |
1004 | } |
1005 | |
1006 | //===----------------------------------------------------------------------===// |
1007 | // EmitC Enums |
1008 | //===----------------------------------------------------------------------===// |
1009 | |
1010 | #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc" |
1011 | |
1012 | //===----------------------------------------------------------------------===// |
1013 | // EmitC Attributes |
1014 | //===----------------------------------------------------------------------===// |
1015 | |
1016 | #define GET_ATTRDEF_CLASSES |
1017 | #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" |
1018 | |
1019 | //===----------------------------------------------------------------------===// |
1020 | // EmitC Types |
1021 | //===----------------------------------------------------------------------===// |
1022 | |
1023 | #define GET_TYPEDEF_CLASSES |
1024 | #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" |
1025 | |
1026 | //===----------------------------------------------------------------------===// |
1027 | // ArrayType |
1028 | //===----------------------------------------------------------------------===// |
1029 | |
1030 | Type emitc::ArrayType::parse(AsmParser &parser) { |
1031 | if (parser.parseLess()) |
1032 | return Type(); |
1033 | |
1034 | SmallVector<int64_t, 4> dimensions; |
1035 | if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false, |
1036 | /*withTrailingX=*/true)) |
1037 | return Type(); |
1038 | // Parse the element type. |
1039 | auto typeLoc = parser.getCurrentLocation(); |
1040 | Type elementType; |
1041 | if (parser.parseType(elementType)) |
1042 | return Type(); |
1043 | |
1044 | // Check that array is formed from allowed types. |
1045 | if (!isValidElementType(elementType)) |
1046 | return parser.emitError(typeLoc, "invalid array element type '") |
1047 | << elementType << "'", |
1048 | Type(); |
1049 | if (parser.parseGreater()) |
1050 | return Type(); |
1051 | return parser.getChecked<ArrayType>(dimensions, elementType); |
1052 | } |
1053 | |
1054 | void emitc::ArrayType::print(AsmPrinter &printer) const { |
1055 | printer << "<"; |
1056 | for (int64_t dim : getShape()) { |
1057 | printer << dim << 'x'; |
1058 | } |
1059 | printer.printType(getElementType()); |
1060 | printer << ">"; |
1061 | } |
1062 | |
1063 | LogicalResult emitc::ArrayType::verify( |
1064 | ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, |
1065 | ::llvm::ArrayRef<int64_t> shape, Type elementType) { |
1066 | if (shape.empty()) |
1067 | return emitError() << "shape must not be empty"; |
1068 | |
1069 | for (int64_t dim : shape) { |
1070 | if (dim < 0) |
1071 | return emitError() << "dimensions must have non-negative size"; |
1072 | } |
1073 | |
1074 | if (!elementType) |
1075 | return emitError() << "element type must not be none"; |
1076 | |
1077 | if (!isValidElementType(elementType)) |
1078 | return emitError() << "invalid array element type"; |
1079 | |
1080 | return success(); |
1081 | } |
1082 | |
1083 | emitc::ArrayType |
1084 | emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape, |
1085 | Type elementType) const { |
1086 | if (!shape) |
1087 | return emitc::ArrayType::get(getShape(), elementType); |
1088 | return emitc::ArrayType::get(*shape, elementType); |
1089 | } |
1090 | |
1091 | //===----------------------------------------------------------------------===// |
1092 | // LValueType |
1093 | //===----------------------------------------------------------------------===// |
1094 | |
1095 | LogicalResult mlir::emitc::LValueType::verify( |
1096 | llvm::function_ref<mlir::InFlightDiagnostic()> emitError, |
1097 | mlir::Type value) { |
1098 | // Check that the wrapped type is valid. This especially forbids nested |
1099 | // lvalue types. |
1100 | if (!isSupportedEmitCType(value)) |
1101 | return emitError() |
1102 | << "!emitc.lvalue must wrap supported emitc type, but got "<< value; |
1103 | |
1104 | if (llvm::isa<emitc::ArrayType>(value)) |
1105 | return emitError() << "!emitc.lvalue cannot wrap !emitc.array type"; |
1106 | |
1107 | return success(); |
1108 | } |
1109 | |
1110 | //===----------------------------------------------------------------------===// |
1111 | // OpaqueType |
1112 | //===----------------------------------------------------------------------===// |
1113 | |
1114 | LogicalResult mlir::emitc::OpaqueType::verify( |
1115 | llvm::function_ref<mlir::InFlightDiagnostic()> emitError, |
1116 | llvm::StringRef value) { |
1117 | if (value.empty()) { |
1118 | return emitError() << "expected non empty string in !emitc.opaque type"; |
1119 | } |
1120 | if (value.back() == '*') { |
1121 | return emitError() << "pointer not allowed as outer type with " |
1122 | "!emitc.opaque, use !emitc.ptr instead"; |
1123 | } |
1124 | return success(); |
1125 | } |
1126 | |
1127 | //===----------------------------------------------------------------------===// |
1128 | // PointerType |
1129 | //===----------------------------------------------------------------------===// |
1130 | |
1131 | LogicalResult mlir::emitc::PointerType::verify( |
1132 | llvm::function_ref<mlir::InFlightDiagnostic()> emitError, Type value) { |
1133 | if (llvm::isa<emitc::LValueType>(value)) |
1134 | return emitError() << "pointers to lvalues are not allowed"; |
1135 | |
1136 | return success(); |
1137 | } |
1138 | |
1139 | //===----------------------------------------------------------------------===// |
1140 | // GlobalOp |
1141 | //===----------------------------------------------------------------------===// |
1142 | static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, |
1143 | TypeAttr type, |
1144 | Attribute initialValue) { |
1145 | p << type; |
1146 | if (initialValue) { |
1147 | p << " = "; |
1148 | p.printAttributeWithoutType(attr: initialValue); |
1149 | } |
1150 | } |
1151 | |
1152 | static Type getInitializerTypeForGlobal(Type type) { |
1153 | if (auto array = llvm::dyn_cast<ArrayType>(type)) |
1154 | return RankedTensorType::get(array.getShape(), array.getElementType()); |
1155 | return type; |
1156 | } |
1157 | |
1158 | static ParseResult |
1159 | parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, |
1160 | Attribute &initialValue) { |
1161 | Type type; |
1162 | if (parser.parseType(result&: type)) |
1163 | return failure(); |
1164 | |
1165 | typeAttr = TypeAttr::get(type); |
1166 | |
1167 | if (parser.parseOptionalEqual()) |
1168 | return success(); |
1169 | |
1170 | if (parser.parseAttribute(result&: initialValue, type: getInitializerTypeForGlobal(type))) |
1171 | return failure(); |
1172 | |
1173 | if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>( |
1174 | initialValue)) |
1175 | return parser.emitError(loc: parser.getNameLoc()) |
1176 | << "initial value should be a integer, float, elements or opaque " |
1177 | "attribute"; |
1178 | return success(); |
1179 | } |
1180 | |
1181 | LogicalResult GlobalOp::verify() { |
1182 | if (!isSupportedEmitCType(getType())) { |
1183 | return emitOpError("expected valid emitc type"); |
1184 | } |
1185 | if (getInitialValue().has_value()) { |
1186 | Attribute initValue = getInitialValue().value(); |
1187 | // Check that the type of the initial value is compatible with the type of |
1188 | // the global variable. |
1189 | if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) { |
1190 | auto arrayType = llvm::dyn_cast<ArrayType>(getType()); |
1191 | if (!arrayType) |
1192 | return emitOpError("expected array type, but got ") << getType(); |
1193 | |
1194 | Type initType = elementsAttr.getType(); |
1195 | Type tensorType = getInitializerTypeForGlobal(getType()); |
1196 | if (initType != tensorType) { |
1197 | return emitOpError("initial value expected to be of type ") |
1198 | << getType() << ", but was of type "<< initType; |
1199 | } |
1200 | } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) { |
1201 | if (intAttr.getType() != getType()) { |
1202 | return emitOpError("initial value expected to be of type ") |
1203 | << getType() << ", but was of type "<< intAttr.getType(); |
1204 | } |
1205 | } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) { |
1206 | if (floatAttr.getType() != getType()) { |
1207 | return emitOpError("initial value expected to be of type ") |
1208 | << getType() << ", but was of type "<< floatAttr.getType(); |
1209 | } |
1210 | } else if (!isa<emitc::OpaqueAttr>(initValue)) { |
1211 | return emitOpError("initial value should be a integer, float, elements " |
1212 | "or opaque attribute, but got ") |
1213 | << initValue; |
1214 | } |
1215 | } |
1216 | if (getStaticSpecifier() && getExternSpecifier()) { |
1217 | return emitOpError("cannot have both static and extern specifiers"); |
1218 | } |
1219 | return success(); |
1220 | } |
1221 | |
1222 | //===----------------------------------------------------------------------===// |
1223 | // GetGlobalOp |
1224 | //===----------------------------------------------------------------------===// |
1225 | |
1226 | LogicalResult |
1227 | GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1228 | // Verify that the type matches the type of the global variable. |
1229 | auto global = |
1230 | symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr()); |
1231 | if (!global) |
1232 | return emitOpError("'") |
1233 | << getName() << "' does not reference a valid emitc.global"; |
1234 | |
1235 | Type resultType = getResult().getType(); |
1236 | Type globalType = global.getType(); |
1237 | |
1238 | // global has array type |
1239 | if (llvm::isa<ArrayType>(globalType)) { |
1240 | if (globalType != resultType) |
1241 | return emitOpError("on array type expects result type ") |
1242 | << resultType << " to match type "<< globalType |
1243 | << " of the global @"<< getName(); |
1244 | return success(); |
1245 | } |
1246 | |
1247 | // global has non-array type |
1248 | auto lvalueType = dyn_cast<LValueType>(resultType); |
1249 | if (!lvalueType || lvalueType.getValueType() != globalType) |
1250 | return emitOpError("on non-array type expects result inner type ") |
1251 | << lvalueType.getValueType() << " to match type "<< globalType |
1252 | << " of the global @"<< getName(); |
1253 | return success(); |
1254 | } |
1255 | |
1256 | //===----------------------------------------------------------------------===// |
1257 | // SwitchOp |
1258 | //===----------------------------------------------------------------------===// |
1259 | |
1260 | /// Parse the case regions and values. |
1261 | static ParseResult |
1262 | parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, |
1263 | SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { |
1264 | SmallVector<int64_t> caseValues; |
1265 | while (succeeded(Result: parser.parseOptionalKeyword(keyword: "case"))) { |
1266 | int64_t value; |
1267 | Region ®ion = *caseRegions.emplace_back(Args: std::make_unique<Region>()); |
1268 | if (parser.parseInteger(result&: value) || |
1269 | parser.parseRegion(region, /*arguments=*/{})) |
1270 | return failure(); |
1271 | caseValues.push_back(Elt: value); |
1272 | } |
1273 | cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues); |
1274 | return success(); |
1275 | } |
1276 | |
1277 | /// Print the case regions and values. |
1278 | static void printSwitchCases(OpAsmPrinter &p, Operation *op, |
1279 | DenseI64ArrayAttr cases, RegionRange caseRegions) { |
1280 | for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { |
1281 | p.printNewline(); |
1282 | p << "case "<< value << ' '; |
1283 | p.printRegion(*region, /*printEntryBlockArgs=*/false); |
1284 | } |
1285 | } |
1286 | |
1287 | static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, |
1288 | const Twine &name) { |
1289 | auto yield = dyn_cast<emitc::YieldOp>(region.front().back()); |
1290 | if (!yield) |
1291 | return op.emitOpError("expected region to end with emitc.yield, but got ") |
1292 | << region.front().back().getName(); |
1293 | |
1294 | if (yield.getNumOperands() != 0) { |
1295 | return (op.emitOpError("expected each region to return ") |
1296 | << "0 values, but "<< name << " returns " |
1297 | << yield.getNumOperands()) |
1298 | .attachNote(yield.getLoc()) |
1299 | << "see yield operation here"; |
1300 | } |
1301 | |
1302 | return success(); |
1303 | } |
1304 | |
1305 | LogicalResult emitc::SwitchOp::verify() { |
1306 | if (!isIntegerIndexOrOpaqueType(getArg().getType())) |
1307 | return emitOpError("unsupported type ") << getArg().getType(); |
1308 | |
1309 | if (getCases().size() != getCaseRegions().size()) { |
1310 | return emitOpError("has ") |
1311 | << getCaseRegions().size() << " case regions but " |
1312 | << getCases().size() << " case values"; |
1313 | } |
1314 | |
1315 | DenseSet<int64_t> valueSet; |
1316 | for (int64_t value : getCases()) |
1317 | if (!valueSet.insert(value).second) |
1318 | return emitOpError("has duplicate case value: ") << value; |
1319 | |
1320 | if (failed(verifyRegion(*this, getDefaultRegion(), "default region"))) |
1321 | return failure(); |
1322 | |
1323 | for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions())) |
1324 | if (failed(verifyRegion(*this, caseRegion, "case region #"+ Twine(idx)))) |
1325 | return failure(); |
1326 | |
1327 | return success(); |
1328 | } |
1329 | |
1330 | unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); } |
1331 | |
1332 | Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); } |
1333 | |
1334 | Block &emitc::SwitchOp::getCaseBlock(unsigned idx) { |
1335 | assert(idx < getNumCases() && "case index out-of-bounds"); |
1336 | return getCaseRegions()[idx].front(); |
1337 | } |
1338 | |
1339 | void SwitchOp::getSuccessorRegions( |
1340 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { |
1341 | llvm::append_range(successors, getRegions()); |
1342 | } |
1343 | |
1344 | void SwitchOp::getEntrySuccessorRegions( |
1345 | ArrayRef<Attribute> operands, |
1346 | SmallVectorImpl<RegionSuccessor> &successors) { |
1347 | FoldAdaptor adaptor(operands, *this); |
1348 | |
1349 | // If a constant was not provided, all regions are possible successors. |
1350 | auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg()); |
1351 | if (!arg) { |
1352 | llvm::append_range(successors, getRegions()); |
1353 | return; |
1354 | } |
1355 | |
1356 | // Otherwise, try to find a case with a matching value. If not, the |
1357 | // default region is the only successor. |
1358 | for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) { |
1359 | if (caseValue == arg.getInt()) { |
1360 | successors.emplace_back(&caseRegion); |
1361 | return; |
1362 | } |
1363 | } |
1364 | successors.emplace_back(&getDefaultRegion()); |
1365 | } |
1366 | |
1367 | void SwitchOp::getRegionInvocationBounds( |
1368 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
1369 | auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front()); |
1370 | if (!operandValue) { |
1371 | // All regions are invoked at most once. |
1372 | bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); |
1373 | return; |
1374 | } |
1375 | |
1376 | unsigned liveIndex = getNumRegions() - 1; |
1377 | const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt()); |
1378 | |
1379 | liveIndex = iteratorToInt != getCases().end() |
1380 | ? std::distance(getCases().begin(), iteratorToInt) |
1381 | : liveIndex; |
1382 | |
1383 | for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum; |
1384 | ++regIndex) |
1385 | bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex); |
1386 | } |
1387 | |
1388 | //===----------------------------------------------------------------------===// |
1389 | // FileOp |
1390 | //===----------------------------------------------------------------------===// |
1391 | void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { |
1392 | state.addRegion()->emplaceBlock(); |
1393 | state.attributes.push_back( |
1394 | builder.getNamedAttr("id", builder.getStringAttr(id))); |
1395 | } |
1396 | |
1397 | //===----------------------------------------------------------------------===// |
1398 | // TableGen'd op method definitions |
1399 | //===----------------------------------------------------------------------===// |
1400 | |
1401 | #define GET_OP_CLASSES |
1402 | #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" |
1403 |
Definitions
- buildTerminatedBody
- isSupportedEmitCType
- isSupportedIntegerType
- isIntegerIndexOrOpaqueType
- isSupportedFloatType
- isPointerWideType
- verifyInitializationAttribute
- parseFormatString
- printEmitCGlobalOpTypeAndInitialValue
- getInitializerTypeForGlobal
- parseEmitCGlobalOpTypeAndInitialValue
- parseSwitchCases
- printSwitchCases
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more