1//===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/EmitC/IR/EmitC.h"
10#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/Types.h"
16#include "mlir/Interfaces/FunctionImplementation.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/Casting.h"
20
21using namespace mlir;
22using namespace mlir::emitc;
23
24#include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
25
26//===----------------------------------------------------------------------===//
27// EmitCDialect
28//===----------------------------------------------------------------------===//
29
/// Registers the dialect's operations, types, and attributes. Each list is
/// obtained by including the corresponding .cpp.inc file with the matching
/// GET_*_LIST macro defined.
void EmitCDialect::initialize() {
 // Operation list from EmitC.cpp.inc.
 addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
 >();
 // Type list from EmitCTypes.cpp.inc.
 addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
 >();
 // Attribute list from EmitCAttributes.cpp.inc.
 addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
 >();
}
44
45/// Materialize a single constant operation from a given attribute value with
46/// the desired resultant type.
47Operation *EmitCDialect::materializeConstant(OpBuilder &builder,
48 Attribute value, Type type,
49 Location loc) {
50 return builder.create<emitc::ConstantOp>(location: loc, args&: type, args&: value);
51}
52
53/// Default callback for builders of ops carrying a region. Inserts a yield
54/// without arguments.
55void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
56 builder.create<emitc::YieldOp>(location: loc);
57}
58
59bool mlir::emitc::isSupportedEmitCType(Type type) {
60 if (llvm::isa<emitc::OpaqueType>(Val: type))
61 return true;
62 if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(Val&: type))
63 return isSupportedEmitCType(type: ptrType.getPointee());
64 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(Val&: type)) {
65 auto elemType = arrayType.getElementType();
66 return !llvm::isa<emitc::ArrayType>(Val: elemType) &&
67 isSupportedEmitCType(type: elemType);
68 }
69 if (type.isIndex() || emitc::isPointerWideType(type))
70 return true;
71 if (llvm::isa<IntegerType>(Val: type))
72 return isSupportedIntegerType(type);
73 if (llvm::isa<FloatType>(Val: type))
74 return isSupportedFloatType(type);
75 if (auto tensorType = llvm::dyn_cast<TensorType>(Val&: type)) {
76 if (!tensorType.hasStaticShape()) {
77 return false;
78 }
79 auto elemType = tensorType.getElementType();
80 if (llvm::isa<emitc::ArrayType>(Val: elemType)) {
81 return false;
82 }
83 return isSupportedEmitCType(type: elemType);
84 }
85 if (auto tupleType = llvm::dyn_cast<TupleType>(Val&: type)) {
86 return llvm::all_of(Range: tupleType.getTypes(), P: [](Type type) {
87 return !llvm::isa<emitc::ArrayType>(Val: type) && isSupportedEmitCType(type);
88 });
89 }
90 return false;
91}
92
93bool mlir::emitc::isSupportedIntegerType(Type type) {
94 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: type)) {
95 switch (intType.getWidth()) {
96 case 1:
97 case 8:
98 case 16:
99 case 32:
100 case 64:
101 return true;
102 default:
103 return false;
104 }
105 }
106 return false;
107}
108
109bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
110 return llvm::isa<IndexType, emitc::OpaqueType>(Val: type) ||
111 isSupportedIntegerType(type) || isPointerWideType(type);
112}
113
114bool mlir::emitc::isSupportedFloatType(Type type) {
115 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
116 switch (floatType.getWidth()) {
117 case 16: {
118 if (llvm::isa<Float16Type, BFloat16Type>(Val: type))
119 return true;
120 return false;
121 }
122 case 32:
123 case 64:
124 return true;
125 default:
126 return false;
127 }
128 }
129 return false;
130}
131
132bool mlir::emitc::isPointerWideType(Type type) {
133 return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
134 Val: type);
135}
136
137/// Check that the type of the initial value is compatible with the operations
138/// result type.
139static LogicalResult verifyInitializationAttribute(Operation *op,
140 Attribute value) {
141 assert(op->getNumResults() == 1 && "operation must have 1 result");
142
143 if (llvm::isa<emitc::OpaqueAttr>(Val: value))
144 return success();
145
146 if (llvm::isa<StringAttr>(Val: value))
147 return op->emitOpError()
148 << "string attributes are not supported, use #emitc.opaque instead";
149
150 Type resultType = op->getResult(idx: 0).getType();
151 if (auto lType = dyn_cast<LValueType>(Val&: resultType))
152 resultType = lType.getValueType();
153 Type attrType = cast<TypedAttr>(Val&: value).getType();
154
155 if (isPointerWideType(type: resultType) && attrType.isIndex())
156 return success();
157
158 if (resultType != attrType)
159 return op->emitOpError()
160 << "requires attribute to either be an #emitc.opaque attribute or "
161 "it's type ("
162 << attrType << ") to match the op's result type (" << resultType
163 << ")";
164
165 return success();
166}
167
168/// Parse a format string and return a list of its parts.
169/// A part is either a StringRef that has to be printed as-is, or
170/// a Placeholder which requires printing the next operand of the VerbatimOp.
171/// In the format string, all `{}` are replaced by Placeholders, except if the
172/// `{` is escaped by `{{` - then it doesn't start a placeholder.
173template <class ArgType>
174FailureOr<SmallVector<ReplacementItem>>
175parseFormatString(StringRef toParse, ArgType fmtArgs,
176 std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
177 emitError = {}) {
178 SmallVector<ReplacementItem> items;
179
180 // If there are not operands, the format string is not interpreted.
181 if (fmtArgs.empty()) {
182 items.push_back(Elt: toParse);
183 return items;
184 }
185
186 while (!toParse.empty()) {
187 size_t idx = toParse.find(C: '{');
188 if (idx == StringRef::npos) {
189 // No '{'
190 items.push_back(Elt: toParse);
191 break;
192 }
193 if (idx > 0) {
194 // Take all chars excluding the '{'.
195 items.push_back(Elt: toParse.take_front(N: idx));
196 toParse = toParse.drop_front(N: idx);
197 continue;
198 }
199 if (toParse.size() < 2) {
200 return (*emitError)()
201 << "expected '}' after unescaped '{' at end of string";
202 }
203 // toParse contains at least two characters and starts with `{`.
204 char nextChar = toParse[1];
205 if (nextChar == '{') {
206 // Double '{{' -> '{' (escaping).
207 items.push_back(Elt: toParse.take_front(N: 1));
208 toParse = toParse.drop_front(N: 2);
209 continue;
210 }
211 if (nextChar == '}') {
212 items.push_back(Elt: Placeholder{});
213 toParse = toParse.drop_front(N: 2);
214 continue;
215 }
216
217 if (emitError.has_value()) {
218 return (*emitError)() << "expected '}' after unescaped '{'";
219 }
220 return failure();
221 }
222 return items;
223}
224
225//===----------------------------------------------------------------------===//
226// AddOp
227//===----------------------------------------------------------------------===//
228
229LogicalResult AddOp::verify() {
230 Type lhsType = getLhs().getType();
231 Type rhsType = getRhs().getType();
232
233 if (isa<emitc::PointerType>(Val: lhsType) && isa<emitc::PointerType>(Val: rhsType))
234 return emitOpError(message: "requires that at most one operand is a pointer");
235
236 if ((isa<emitc::PointerType>(Val: lhsType) &&
237 !isa<IntegerType, emitc::OpaqueType>(Val: rhsType)) ||
238 (isa<emitc::PointerType>(Val: rhsType) &&
239 !isa<IntegerType, emitc::OpaqueType>(Val: lhsType)))
240 return emitOpError(message: "requires that one operand is an integer or of opaque "
241 "type if the other is a pointer");
242
243 return success();
244}
245
246//===----------------------------------------------------------------------===//
247// ApplyOp
248//===----------------------------------------------------------------------===//
249
250LogicalResult ApplyOp::verify() {
251 StringRef applicableOperatorStr = getApplicableOperator();
252
253 // Applicable operator must not be empty.
254 if (applicableOperatorStr.empty())
255 return emitOpError(message: "applicable operator must not be empty");
256
257 // Only `*` and `&` are supported.
258 if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
259 return emitOpError(message: "applicable operator is illegal");
260
261 Type operandType = getOperand().getType();
262 Type resultType = getResult().getType();
263 if (applicableOperatorStr == "&") {
264 if (!llvm::isa<emitc::LValueType>(Val: operandType))
265 return emitOpError(message: "operand type must be an lvalue when applying `&`");
266 if (!llvm::isa<emitc::PointerType>(Val: resultType))
267 return emitOpError(message: "result type must be a pointer when applying `&`");
268 } else {
269 if (!llvm::isa<emitc::PointerType>(Val: operandType))
270 return emitOpError(message: "operand type must be a pointer when applying `*`");
271 }
272
273 return success();
274}
275
276//===----------------------------------------------------------------------===//
277// AssignOp
278//===----------------------------------------------------------------------===//
279
280/// The assign op requires that the assigned value's type matches the
281/// assigned-to variable type.
282LogicalResult emitc::AssignOp::verify() {
283 TypedValue<emitc::LValueType> variable = getVar();
284
285 if (!variable.getDefiningOp())
286 return emitOpError() << "cannot assign to block argument";
287
288 Type valueType = getValue().getType();
289 Type variableType = variable.getType().getValueType();
290 if (variableType != valueType)
291 return emitOpError() << "requires value's type (" << valueType
292 << ") to match variable's type (" << variableType
293 << ")\n variable: " << variable
294 << "\n value: " << getValue() << "\n";
295 return success();
296}
297
298//===----------------------------------------------------------------------===//
299// CastOp
300//===----------------------------------------------------------------------===//
301
302bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
303 Type input = inputs.front(), output = outputs.front();
304
305 if (auto arrayType = dyn_cast<emitc::ArrayType>(Val&: input)) {
306 if (auto pointerType = dyn_cast<emitc::PointerType>(Val&: output)) {
307 return (arrayType.getElementType() == pointerType.getPointee()) &&
308 arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
309 }
310 return false;
311 }
312
313 return (
314 (emitc::isIntegerIndexOrOpaqueType(type: input) ||
315 emitc::isSupportedFloatType(type: input) || isa<emitc::PointerType>(Val: input)) &&
316 (emitc::isIntegerIndexOrOpaqueType(type: output) ||
317 emitc::isSupportedFloatType(type: output) || isa<emitc::PointerType>(Val: output)));
318}
319
320//===----------------------------------------------------------------------===//
321// CallOpaqueOp
322//===----------------------------------------------------------------------===//
323
324LogicalResult emitc::CallOpaqueOp::verify() {
325 // Callee must not be empty.
326 if (getCallee().empty())
327 return emitOpError(message: "callee must not be empty");
328
329 if (std::optional<ArrayAttr> argsAttr = getArgs()) {
330 for (Attribute arg : *argsAttr) {
331 auto intAttr = llvm::dyn_cast<IntegerAttr>(Val&: arg);
332 if (intAttr && llvm::isa<IndexType>(Val: intAttr.getType())) {
333 int64_t index = intAttr.getInt();
334 // Args with elements of type index must be in range
335 // [0..operands.size).
336 if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
337 return emitOpError(message: "index argument is out of range");
338
339 // Args with elements of type ArrayAttr must have a type.
340 } else if (llvm::isa<ArrayAttr>(
341 Val: arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
342 // FIXME: Array attributes never have types
343 return emitOpError(message: "array argument has no type");
344 }
345 }
346 }
347
348 if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
349 for (Attribute tArg : *templateArgsAttr) {
350 if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(Val: tArg))
351 return emitOpError(message: "template argument has invalid type");
352 }
353 }
354
355 if (llvm::any_of(Range: getResultTypes(), P: llvm::IsaPred<ArrayType>)) {
356 return emitOpError() << "cannot return array type";
357 }
358
359 return success();
360}
361
362//===----------------------------------------------------------------------===//
363// ConstantOp
364//===----------------------------------------------------------------------===//
365
366LogicalResult emitc::ConstantOp::verify() {
367 Attribute value = getValueAttr();
368 if (failed(Result: verifyInitializationAttribute(op: getOperation(), value)))
369 return failure();
370 if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(Val&: value)) {
371 if (opaqueValue.getValue().empty())
372 return emitOpError() << "value must not be empty";
373 }
374 return success();
375}
376
377OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
378
379//===----------------------------------------------------------------------===//
380// ExpressionOp
381//===----------------------------------------------------------------------===//
382
383Operation *ExpressionOp::getRootOp() {
384 auto yieldOp = cast<YieldOp>(Val: getBody()->getTerminator());
385 Value yieldedValue = yieldOp.getResult();
386 return yieldedValue.getDefiningOp();
387}
388
389LogicalResult ExpressionOp::verify() {
390 Type resultType = getResult().getType();
391 Region &region = getRegion();
392
393 Block &body = region.front();
394
395 if (!body.mightHaveTerminator())
396 return emitOpError(message: "must yield a value at termination");
397
398 auto yield = cast<YieldOp>(Val: body.getTerminator());
399 Value yieldResult = yield.getResult();
400
401 if (!yieldResult)
402 return emitOpError(message: "must yield a value at termination");
403
404 Operation *rootOp = yieldResult.getDefiningOp();
405
406 if (!rootOp)
407 return emitOpError(message: "yielded value has no defining op");
408
409 if (rootOp->getParentOp() != getOperation())
410 return emitOpError(message: "yielded value not defined within expression");
411
412 Type yieldType = yieldResult.getType();
413
414 if (resultType != yieldType)
415 return emitOpError(message: "requires yielded type to match return type");
416
417 for (Operation &op : region.front().without_terminator()) {
418 if (!isa<emitc::CExpressionInterface>(Val: op))
419 return emitOpError(message: "contains an unsupported operation");
420 if (op.getNumResults() != 1)
421 return emitOpError(message: "requires exactly one result for each operation");
422 if (!op.getResult(idx: 0).hasOneUse())
423 return emitOpError(message: "requires exactly one use for each operation");
424 }
425
426 return success();
427}
428
429//===----------------------------------------------------------------------===//
430// ForOp
431//===----------------------------------------------------------------------===//
432
433void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
434 Value ub, Value step, BodyBuilderFn bodyBuilder) {
435 OpBuilder::InsertionGuard g(builder);
436 result.addOperands(newOperands: {lb, ub, step});
437 Type t = lb.getType();
438 Region *bodyRegion = result.addRegion();
439 Block *bodyBlock = builder.createBlock(parent: bodyRegion);
440 bodyBlock->addArgument(type: t, loc: result.location);
441
442 // Create the default terminator if the builder is not provided.
443 if (!bodyBuilder) {
444 ForOp::ensureTerminator(region&: *bodyRegion, builder, loc: result.location);
445 } else {
446 OpBuilder::InsertionGuard guard(builder);
447 builder.setInsertionPointToStart(bodyBlock);
448 bodyBuilder(builder, result.location, bodyBlock->getArgument(i: 0));
449 }
450}
451
452void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
453
454ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
455 Builder &builder = parser.getBuilder();
456 Type type;
457
458 OpAsmParser::Argument inductionVariable;
459 OpAsmParser::UnresolvedOperand lb, ub, step;
460
461 // Parse the induction variable followed by '='.
462 if (parser.parseOperand(result&: inductionVariable.ssaName) || parser.parseEqual() ||
463 // Parse loop bounds.
464 parser.parseOperand(result&: lb) || parser.parseKeyword(keyword: "to") ||
465 parser.parseOperand(result&: ub) || parser.parseKeyword(keyword: "step") ||
466 parser.parseOperand(result&: step))
467 return failure();
468
469 // Parse the optional initial iteration arguments.
470 SmallVector<OpAsmParser::Argument, 4> regionArgs;
471 regionArgs.push_back(Elt: inductionVariable);
472
473 // Parse optional type, else assume Index.
474 if (parser.parseOptionalColon())
475 type = builder.getIndexType();
476 else if (parser.parseType(result&: type))
477 return failure();
478
479 // Resolve input operands.
480 regionArgs.front().type = type;
481 if (parser.resolveOperand(operand: lb, type, result&: result.operands) ||
482 parser.resolveOperand(operand: ub, type, result&: result.operands) ||
483 parser.resolveOperand(operand: step, type, result&: result.operands))
484 return failure();
485
486 // Parse the body region.
487 Region *body = result.addRegion();
488 if (parser.parseRegion(region&: *body, arguments: regionArgs))
489 return failure();
490
491 ForOp::ensureTerminator(region&: *body, builder, loc: result.location);
492
493 // Parse the optional attribute list.
494 if (parser.parseOptionalAttrDict(result&: result.attributes))
495 return failure();
496
497 return success();
498}
499
500void ForOp::print(OpAsmPrinter &p) {
501 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
502 << getUpperBound() << " step " << getStep();
503
504 p << ' ';
505 if (Type t = getInductionVar().getType(); !t.isIndex())
506 p << " : " << t << ' ';
507 p.printRegion(blocks&: getRegion(),
508 /*printEntryBlockArgs=*/false,
509 /*printBlockTerminators=*/false);
510 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
511}
512
513LogicalResult ForOp::verifyRegions() {
514 // Check that the body defines as single block argument for the induction
515 // variable.
516 if (getInductionVar().getType() != getLowerBound().getType())
517 return emitOpError(
518 message: "expected induction variable to be same type as bounds and step");
519
520 return success();
521}
522
523//===----------------------------------------------------------------------===//
524// CallOp
525//===----------------------------------------------------------------------===//
526
527LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
528 // Check that the callee attribute was specified.
529 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>(name: "callee");
530 if (!fnAttr)
531 return emitOpError(message: "requires a 'callee' symbol reference attribute");
532 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(from: *this, symbol: fnAttr);
533 if (!fn)
534 return emitOpError() << "'" << fnAttr.getValue()
535 << "' does not reference a valid function";
536
537 // Verify that the operand and result types match the callee.
538 auto fnType = fn.getFunctionType();
539 if (fnType.getNumInputs() != getNumOperands())
540 return emitOpError(message: "incorrect number of operands for callee");
541
542 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
543 if (getOperand(i).getType() != fnType.getInput(i))
544 return emitOpError(message: "operand type mismatch: expected operand type ")
545 << fnType.getInput(i) << ", but provided "
546 << getOperand(i).getType() << " for operand number " << i;
547
548 if (fnType.getNumResults() != getNumResults())
549 return emitOpError(message: "incorrect number of results for callee");
550
551 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
552 if (getResult(i).getType() != fnType.getResult(i)) {
553 auto diag = emitOpError(message: "result type mismatch at index ") << i;
554 diag.attachNote() << " op result types: " << getResultTypes();
555 diag.attachNote() << "function result types: " << fnType.getResults();
556 return diag;
557 }
558
559 return success();
560}
561
562FunctionType CallOp::getCalleeType() {
563 return FunctionType::get(context: getContext(), inputs: getOperandTypes(), results: getResultTypes());
564}
565
566//===----------------------------------------------------------------------===//
567// DeclareFuncOp
568//===----------------------------------------------------------------------===//
569
570LogicalResult
571DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
572 // Check that the sym_name attribute was specified.
573 auto fnAttr = getSymNameAttr();
574 if (!fnAttr)
575 return emitOpError(message: "requires a 'sym_name' symbol reference attribute");
576 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(from: *this, symbol: fnAttr);
577 if (!fn)
578 return emitOpError() << "'" << fnAttr.getValue()
579 << "' does not reference a valid function";
580
581 return success();
582}
583
584//===----------------------------------------------------------------------===//
585// FuncOp
586//===----------------------------------------------------------------------===//
587
588void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
589 FunctionType type, ArrayRef<NamedAttribute> attrs,
590 ArrayRef<DictionaryAttr> argAttrs) {
591 state.addAttribute(name: SymbolTable::getSymbolAttrName(),
592 attr: builder.getStringAttr(bytes: name));
593 state.addAttribute(name: getFunctionTypeAttrName(name: state.name), attr: TypeAttr::get(type));
594 state.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
595 state.addRegion();
596
597 if (argAttrs.empty())
598 return;
599 assert(type.getNumInputs() == argAttrs.size());
600 call_interface_impl::addArgAndResultAttrs(
601 builder, result&: state, argAttrs, /*resultAttrs=*/{},
602 argAttrsName: getArgAttrsAttrName(name: state.name), resAttrsName: getResAttrsAttrName(name: state.name));
603}
604
605ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
606 auto buildFuncType =
607 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
608 function_interface_impl::VariadicFlag,
609 std::string &) { return builder.getFunctionType(inputs: argTypes, results); };
610
611 return function_interface_impl::parseFunctionOp(
612 parser, result, /*allowVariadic=*/false,
613 typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType,
614 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
615}
616
617void FuncOp::print(OpAsmPrinter &p) {
618 function_interface_impl::printFunctionOp(
619 p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(),
620 argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName());
621}
622
623LogicalResult FuncOp::verify() {
624 if (llvm::any_of(Range: getArgumentTypes(), P: llvm::IsaPred<LValueType>)) {
625 return emitOpError(message: "cannot have lvalue type as argument");
626 }
627
628 if (getNumResults() > 1)
629 return emitOpError(message: "requires zero or exactly one result, but has ")
630 << getNumResults();
631
632 if (getNumResults() == 1 && isa<ArrayType>(Val: getResultTypes()[0]))
633 return emitOpError(message: "cannot return array type");
634
635 return success();
636}
637
638//===----------------------------------------------------------------------===//
639// ReturnOp
640//===----------------------------------------------------------------------===//
641
642LogicalResult ReturnOp::verify() {
643 auto function = cast<FuncOp>(Val: (*this)->getParentOp());
644
645 // The operand number and types must match the function signature.
646 if (getNumOperands() != function.getNumResults())
647 return emitOpError(message: "has ")
648 << getNumOperands() << " operands, but enclosing function (@"
649 << function.getName() << ") returns " << function.getNumResults();
650
651 if (function.getNumResults() == 1)
652 if (getOperand().getType() != function.getResultTypes()[0])
653 return emitError() << "type of the return operand ("
654 << getOperand().getType()
655 << ") doesn't match function result type ("
656 << function.getResultTypes()[0] << ")"
657 << " in function @" << function.getName();
658 return success();
659}
660
661//===----------------------------------------------------------------------===//
662// IfOp
663//===----------------------------------------------------------------------===//
664
665void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
666 bool addThenBlock, bool addElseBlock) {
667 assert((!addElseBlock || addThenBlock) &&
668 "must not create else block w/o then block");
669 result.addOperands(newOperands: cond);
670
671 // Add regions and blocks.
672 OpBuilder::InsertionGuard guard(builder);
673 Region *thenRegion = result.addRegion();
674 if (addThenBlock)
675 builder.createBlock(parent: thenRegion);
676 Region *elseRegion = result.addRegion();
677 if (addElseBlock)
678 builder.createBlock(parent: elseRegion);
679}
680
681void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
682 bool withElseRegion) {
683 result.addOperands(newOperands: cond);
684
685 // Build then region.
686 OpBuilder::InsertionGuard guard(builder);
687 Region *thenRegion = result.addRegion();
688 builder.createBlock(parent: thenRegion);
689
690 // Build else region.
691 Region *elseRegion = result.addRegion();
692 if (withElseRegion) {
693 builder.createBlock(parent: elseRegion);
694 }
695}
696
697void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
698 function_ref<void(OpBuilder &, Location)> thenBuilder,
699 function_ref<void(OpBuilder &, Location)> elseBuilder) {
700 assert(thenBuilder && "the builder callback for 'then' must be present");
701 result.addOperands(newOperands: cond);
702
703 // Build then region.
704 OpBuilder::InsertionGuard guard(builder);
705 Region *thenRegion = result.addRegion();
706 builder.createBlock(parent: thenRegion);
707 thenBuilder(builder, result.location);
708
709 // Build else region.
710 Region *elseRegion = result.addRegion();
711 if (elseBuilder) {
712 builder.createBlock(parent: elseRegion);
713 elseBuilder(builder, result.location);
714 }
715}
716
717ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
718 // Create the regions for 'then'.
719 result.regions.reserve(N: 2);
720 Region *thenRegion = result.addRegion();
721 Region *elseRegion = result.addRegion();
722
723 Builder &builder = parser.getBuilder();
724 OpAsmParser::UnresolvedOperand cond;
725 Type i1Type = builder.getIntegerType(width: 1);
726 if (parser.parseOperand(result&: cond) ||
727 parser.resolveOperand(operand: cond, type: i1Type, result&: result.operands))
728 return failure();
729 // Parse the 'then' region.
730 if (parser.parseRegion(region&: *thenRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
731 return failure();
732 IfOp::ensureTerminator(region&: *thenRegion, builder&: parser.getBuilder(), loc: result.location);
733
734 // If we find an 'else' keyword then parse the 'else' region.
735 if (!parser.parseOptionalKeyword(keyword: "else")) {
736 if (parser.parseRegion(region&: *elseRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
737 return failure();
738 IfOp::ensureTerminator(region&: *elseRegion, builder&: parser.getBuilder(), loc: result.location);
739 }
740
741 // Parse the optional attribute list.
742 if (parser.parseOptionalAttrDict(result&: result.attributes))
743 return failure();
744 return success();
745}
746
747void IfOp::print(OpAsmPrinter &p) {
748 bool printBlockTerminators = false;
749
750 p << " " << getCondition();
751 p << ' ';
752 p.printRegion(blocks&: getThenRegion(),
753 /*printEntryBlockArgs=*/false,
754 /*printBlockTerminators=*/printBlockTerminators);
755
756 // Print the 'else' regions if it exists and has a block.
757 Region &elseRegion = getElseRegion();
758 if (!elseRegion.empty()) {
759 p << " else ";
760 p.printRegion(blocks&: elseRegion,
761 /*printEntryBlockArgs=*/false,
762 /*printBlockTerminators=*/printBlockTerminators);
763 }
764
765 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
766}
767
768/// Given the region at `index`, or the parent operation if `index` is None,
769/// return the successor regions. These are the regions that may be selected
770/// during the flow of control. `operands` is a set of optional attributes
771/// that correspond to a constant value for each operand, or null if that
772/// operand is not a constant.
773void IfOp::getSuccessorRegions(RegionBranchPoint point,
774 SmallVectorImpl<RegionSuccessor> &regions) {
775 // The `then` and the `else` region branch back to the parent operation.
776 if (!point.isParent()) {
777 regions.push_back(Elt: RegionSuccessor());
778 return;
779 }
780
781 regions.push_back(Elt: RegionSuccessor(&getThenRegion()));
782
783 // Don't consider the else region if it is empty.
784 Region *elseRegion = &this->getElseRegion();
785 if (elseRegion->empty())
786 regions.push_back(Elt: RegionSuccessor());
787 else
788 regions.push_back(Elt: RegionSuccessor(elseRegion));
789}
790
791void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
792 SmallVectorImpl<RegionSuccessor> &regions) {
793 FoldAdaptor adaptor(operands, *this);
794 auto boolAttr = dyn_cast_or_null<BoolAttr>(Val: adaptor.getCondition());
795 if (!boolAttr || boolAttr.getValue())
796 regions.emplace_back(Args: &getThenRegion());
797
798 // If the else region is empty, execution continues after the parent op.
799 if (!boolAttr || !boolAttr.getValue()) {
800 if (!getElseRegion().empty())
801 regions.emplace_back(Args: &getElseRegion());
802 else
803 regions.emplace_back();
804 }
805}
806
807void IfOp::getRegionInvocationBounds(
808 ArrayRef<Attribute> operands,
809 SmallVectorImpl<InvocationBounds> &invocationBounds) {
810 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(Val: operands[0])) {
811 // If the condition is known, then one region is known to be executed once
812 // and the other zero times.
813 invocationBounds.emplace_back(Args: 0, Args: cond.getValue() ? 1 : 0);
814 invocationBounds.emplace_back(Args: 0, Args: cond.getValue() ? 0 : 1);
815 } else {
816 // Non-constant condition. Each region may be executed 0 or 1 times.
817 invocationBounds.assign(NumElts: 2, Elt: {0, 1});
818 }
819}
820
821//===----------------------------------------------------------------------===//
822// IncludeOp
823//===----------------------------------------------------------------------===//
824
825void IncludeOp::print(OpAsmPrinter &p) {
826 bool standardInclude = getIsStandardInclude();
827
828 p << " ";
829 if (standardInclude)
830 p << "<";
831 p << "\"" << getInclude() << "\"";
832 if (standardInclude)
833 p << ">";
834}
835
836ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
837 bool standardInclude = !parser.parseOptionalLess();
838
839 StringAttr include;
840 OptionalParseResult includeParseResult =
841 parser.parseOptionalAttribute(result&: include, attrName: "include", attrs&: result.attributes);
842 if (!includeParseResult.has_value())
843 return parser.emitError(loc: parser.getNameLoc()) << "expected string attribute";
844
845 if (standardInclude && parser.parseOptionalGreater())
846 return parser.emitError(loc: parser.getNameLoc())
847 << "expected trailing '>' for standard include";
848
849 if (standardInclude)
850 result.addAttribute(name: "is_standard_include",
851 attr: UnitAttr::get(context: parser.getContext()));
852
853 return success();
854}
855
856//===----------------------------------------------------------------------===//
857// LiteralOp
858//===----------------------------------------------------------------------===//
859
860/// The literal op requires a non-empty value.
861LogicalResult emitc::LiteralOp::verify() {
862 if (getValue().empty())
863 return emitOpError() << "value must not be empty";
864 return success();
865}
866//===----------------------------------------------------------------------===//
867// SubOp
868//===----------------------------------------------------------------------===//
869
870LogicalResult SubOp::verify() {
871 Type lhsType = getLhs().getType();
872 Type rhsType = getRhs().getType();
873 Type resultType = getResult().getType();
874
875 if (isa<emitc::PointerType>(Val: rhsType) && !isa<emitc::PointerType>(Val: lhsType))
876 return emitOpError(message: "rhs can only be a pointer if lhs is a pointer");
877
878 if (isa<emitc::PointerType>(Val: lhsType) &&
879 !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(Val: rhsType))
880 return emitOpError(message: "requires that rhs is an integer, pointer or of opaque "
881 "type if lhs is a pointer");
882
883 if (isa<emitc::PointerType>(Val: lhsType) && isa<emitc::PointerType>(Val: rhsType) &&
884 !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(Val: resultType))
885 return emitOpError(message: "requires that the result is an integer, ptrdiff_t or "
886 "of opaque type if lhs and rhs are pointers");
887 return success();
888}
889
890//===----------------------------------------------------------------------===//
891// VariableOp
892//===----------------------------------------------------------------------===//
893
894LogicalResult emitc::VariableOp::verify() {
895 return verifyInitializationAttribute(op: getOperation(), value: getValueAttr());
896}
897
898//===----------------------------------------------------------------------===//
899// YieldOp
900//===----------------------------------------------------------------------===//
901
902LogicalResult emitc::YieldOp::verify() {
903 Value result = getResult();
904 Operation *containingOp = getOperation()->getParentOp();
905
906 if (result && containingOp->getNumResults() != 1)
907 return emitOpError() << "yields a value not returned by parent";
908
909 if (!result && containingOp->getNumResults() != 0)
910 return emitOpError() << "does not yield a value to be returned by parent";
911
912 return success();
913}
914
915//===----------------------------------------------------------------------===//
916// SubscriptOp
917//===----------------------------------------------------------------------===//
918
919LogicalResult emitc::SubscriptOp::verify() {
920 // Checks for array operand.
921 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(Val: getValue().getType())) {
922 // Check number of indices.
923 if (getIndices().size() != (size_t)arrayType.getRank()) {
924 return emitOpError() << "on array operand requires number of indices ("
925 << getIndices().size()
926 << ") to match the rank of the array type ("
927 << arrayType.getRank() << ")";
928 }
929 // Check types of index operands.
930 for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
931 Type type = getIndices()[i].getType();
932 if (!isIntegerIndexOrOpaqueType(type)) {
933 return emitOpError() << "on array operand requires index operand " << i
934 << " to be integer-like, but got " << type;
935 }
936 }
937 // Check element type.
938 Type elementType = arrayType.getElementType();
939 Type resultType = getType().getValueType();
940 if (elementType != resultType) {
941 return emitOpError() << "on array operand requires element type ("
942 << elementType << ") and result type (" << resultType
943 << ") to match";
944 }
945 return success();
946 }
947
948 // Checks for pointer operand.
949 if (auto pointerType =
950 llvm::dyn_cast<emitc::PointerType>(Val: getValue().getType())) {
951 // Check number of indices.
952 if (getIndices().size() != 1) {
953 return emitOpError()
954 << "on pointer operand requires one index operand, but got "
955 << getIndices().size();
956 }
957 // Check types of index operand.
958 Type type = getIndices()[0].getType();
959 if (!isIntegerIndexOrOpaqueType(type)) {
960 return emitOpError() << "on pointer operand requires index operand to be "
961 "integer-like, but got "
962 << type;
963 }
964 // Check pointee type.
965 Type pointeeType = pointerType.getPointee();
966 Type resultType = getType().getValueType();
967 if (pointeeType != resultType) {
968 return emitOpError() << "on pointer operand requires pointee type ("
969 << pointeeType << ") and result type (" << resultType
970 << ") to match";
971 }
972 return success();
973 }
974
975 // The operand has opaque type, so we can't assume anything about the number
976 // or types of index operands.
977 return success();
978}
979
980//===----------------------------------------------------------------------===//
981// VerbatimOp
982//===----------------------------------------------------------------------===//
983
984LogicalResult emitc::VerbatimOp::verify() {
985 auto errorCallback = [&]() -> InFlightDiagnostic {
986 return this->emitOpError();
987 };
988 FailureOr<SmallVector<ReplacementItem>> fmt =
989 ::parseFormatString(toParse: getValue(), fmtArgs: getFmtArgs(), emitError: errorCallback);
990 if (failed(Result: fmt))
991 return failure();
992
993 size_t numPlaceholders = llvm::count_if(Range&: *fmt, P: [](ReplacementItem &item) {
994 return std::holds_alternative<Placeholder>(v: item);
995 });
996
997 if (numPlaceholders != getFmtArgs().size()) {
998 return emitOpError()
999 << "requires operands for each placeholder in the format string";
1000 }
1001 return success();
1002}
1003
1004FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1005 // Error checking is done in verify.
1006 return ::parseFormatString(toParse: getValue(), fmtArgs: getFmtArgs());
1007}
1008
1009//===----------------------------------------------------------------------===//
1010// EmitC Enums
1011//===----------------------------------------------------------------------===//
1012
1013#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1014
1015//===----------------------------------------------------------------------===//
1016// EmitC Attributes
1017//===----------------------------------------------------------------------===//
1018
1019#define GET_ATTRDEF_CLASSES
1020#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1021
1022//===----------------------------------------------------------------------===//
1023// EmitC Types
1024//===----------------------------------------------------------------------===//
1025
1026#define GET_TYPEDEF_CLASSES
1027#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1028
1029//===----------------------------------------------------------------------===//
1030// ArrayType
1031//===----------------------------------------------------------------------===//
1032
1033Type emitc::ArrayType::parse(AsmParser &parser) {
1034 if (parser.parseLess())
1035 return Type();
1036
1037 SmallVector<int64_t, 4> dimensions;
1038 if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
1039 /*withTrailingX=*/true))
1040 return Type();
1041 // Parse the element type.
1042 auto typeLoc = parser.getCurrentLocation();
1043 Type elementType;
1044 if (parser.parseType(result&: elementType))
1045 return Type();
1046
1047 // Check that array is formed from allowed types.
1048 if (!isValidElementType(type: elementType))
1049 return parser.emitError(loc: typeLoc, message: "invalid array element type '")
1050 << elementType << "'",
1051 Type();
1052 if (parser.parseGreater())
1053 return Type();
1054 return parser.getChecked<ArrayType>(params&: dimensions, params&: elementType);
1055}
1056
1057void emitc::ArrayType::print(AsmPrinter &printer) const {
1058 printer << "<";
1059 for (int64_t dim : getShape()) {
1060 printer << dim << 'x';
1061 }
1062 printer.printType(type: getElementType());
1063 printer << ">";
1064}
1065
1066LogicalResult emitc::ArrayType::verify(
1067 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
1068 ::llvm::ArrayRef<int64_t> shape, Type elementType) {
1069 if (shape.empty())
1070 return emitError() << "shape must not be empty";
1071
1072 for (int64_t dim : shape) {
1073 if (dim < 0)
1074 return emitError() << "dimensions must have non-negative size";
1075 }
1076
1077 if (!elementType)
1078 return emitError() << "element type must not be none";
1079
1080 if (!isValidElementType(type: elementType))
1081 return emitError() << "invalid array element type";
1082
1083 return success();
1084}
1085
1086emitc::ArrayType
1087emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1088 Type elementType) const {
1089 if (!shape)
1090 return emitc::ArrayType::get(shape: getShape(), elementType);
1091 return emitc::ArrayType::get(shape: *shape, elementType);
1092}
1093
1094//===----------------------------------------------------------------------===//
1095// LValueType
1096//===----------------------------------------------------------------------===//
1097
1098LogicalResult mlir::emitc::LValueType::verify(
1099 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1100 mlir::Type value) {
1101 // Check that the wrapped type is valid. This especially forbids nested
1102 // lvalue types.
1103 if (!isSupportedEmitCType(type: value))
1104 return emitError()
1105 << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1106
1107 if (llvm::isa<emitc::ArrayType>(Val: value))
1108 return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1109
1110 return success();
1111}
1112
1113//===----------------------------------------------------------------------===//
1114// OpaqueType
1115//===----------------------------------------------------------------------===//
1116
1117LogicalResult mlir::emitc::OpaqueType::verify(
1118 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1119 llvm::StringRef value) {
1120 if (value.empty()) {
1121 return emitError() << "expected non empty string in !emitc.opaque type";
1122 }
1123 if (value.back() == '*') {
1124 return emitError() << "pointer not allowed as outer type with "
1125 "!emitc.opaque, use !emitc.ptr instead";
1126 }
1127 return success();
1128}
1129
1130//===----------------------------------------------------------------------===//
1131// PointerType
1132//===----------------------------------------------------------------------===//
1133
1134LogicalResult mlir::emitc::PointerType::verify(
1135 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, Type value) {
1136 if (llvm::isa<emitc::LValueType>(Val: value))
1137 return emitError() << "pointers to lvalues are not allowed";
1138
1139 return success();
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// GlobalOp
1144//===----------------------------------------------------------------------===//
1145static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1146 TypeAttr type,
1147 Attribute initialValue) {
1148 p << type;
1149 if (initialValue) {
1150 p << " = ";
1151 p.printAttributeWithoutType(attr: initialValue);
1152 }
1153}
1154
1155static Type getInitializerTypeForGlobal(Type type) {
1156 if (auto array = llvm::dyn_cast<ArrayType>(Val&: type))
1157 return RankedTensorType::get(shape: array.getShape(), elementType: array.getElementType());
1158 return type;
1159}
1160
1161static ParseResult
1162parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1163 Attribute &initialValue) {
1164 Type type;
1165 if (parser.parseType(result&: type))
1166 return failure();
1167
1168 typeAttr = TypeAttr::get(type);
1169
1170 if (parser.parseOptionalEqual())
1171 return success();
1172
1173 if (parser.parseAttribute(result&: initialValue, type: getInitializerTypeForGlobal(type)))
1174 return failure();
1175
1176 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1177 Val: initialValue))
1178 return parser.emitError(loc: parser.getNameLoc())
1179 << "initial value should be a integer, float, elements or opaque "
1180 "attribute";
1181 return success();
1182}
1183
1184LogicalResult GlobalOp::verify() {
1185 if (!isSupportedEmitCType(type: getType())) {
1186 return emitOpError(message: "expected valid emitc type");
1187 }
1188 if (getInitialValue().has_value()) {
1189 Attribute initValue = getInitialValue().value();
1190 // Check that the type of the initial value is compatible with the type of
1191 // the global variable.
1192 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(Val&: initValue)) {
1193 auto arrayType = llvm::dyn_cast<ArrayType>(Val: getType());
1194 if (!arrayType)
1195 return emitOpError(message: "expected array type, but got ") << getType();
1196
1197 Type initType = elementsAttr.getType();
1198 Type tensorType = getInitializerTypeForGlobal(type: getType());
1199 if (initType != tensorType) {
1200 return emitOpError(message: "initial value expected to be of type ")
1201 << getType() << ", but was of type " << initType;
1202 }
1203 } else if (auto intAttr = dyn_cast<IntegerAttr>(Val&: initValue)) {
1204 if (intAttr.getType() != getType()) {
1205 return emitOpError(message: "initial value expected to be of type ")
1206 << getType() << ", but was of type " << intAttr.getType();
1207 }
1208 } else if (auto floatAttr = dyn_cast<FloatAttr>(Val&: initValue)) {
1209 if (floatAttr.getType() != getType()) {
1210 return emitOpError(message: "initial value expected to be of type ")
1211 << getType() << ", but was of type " << floatAttr.getType();
1212 }
1213 } else if (!isa<emitc::OpaqueAttr>(Val: initValue)) {
1214 return emitOpError(message: "initial value should be a integer, float, elements "
1215 "or opaque attribute, but got ")
1216 << initValue;
1217 }
1218 }
1219 if (getStaticSpecifier() && getExternSpecifier()) {
1220 return emitOpError(message: "cannot have both static and extern specifiers");
1221 }
1222 return success();
1223}
1224
1225//===----------------------------------------------------------------------===//
1226// GetGlobalOp
1227//===----------------------------------------------------------------------===//
1228
1229LogicalResult
1230GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1231 // Verify that the type matches the type of the global variable.
1232 auto global =
1233 symbolTable.lookupNearestSymbolFrom<GlobalOp>(from: *this, symbol: getNameAttr());
1234 if (!global)
1235 return emitOpError(message: "'")
1236 << getName() << "' does not reference a valid emitc.global";
1237
1238 Type resultType = getResult().getType();
1239 Type globalType = global.getType();
1240
1241 // global has array type
1242 if (llvm::isa<ArrayType>(Val: globalType)) {
1243 if (globalType != resultType)
1244 return emitOpError(message: "on array type expects result type ")
1245 << resultType << " to match type " << globalType
1246 << " of the global @" << getName();
1247 return success();
1248 }
1249
1250 // global has non-array type
1251 auto lvalueType = dyn_cast<LValueType>(Val&: resultType);
1252 if (!lvalueType || lvalueType.getValueType() != globalType)
1253 return emitOpError(message: "on non-array type expects result inner type ")
1254 << lvalueType.getValueType() << " to match type " << globalType
1255 << " of the global @" << getName();
1256 return success();
1257}
1258
1259//===----------------------------------------------------------------------===//
1260// SwitchOp
1261//===----------------------------------------------------------------------===//
1262
1263/// Parse the case regions and values.
1264static ParseResult
1265parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
1266 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1267 SmallVector<int64_t> caseValues;
1268 while (succeeded(Result: parser.parseOptionalKeyword(keyword: "case"))) {
1269 int64_t value;
1270 Region &region = *caseRegions.emplace_back(Args: std::make_unique<Region>());
1271 if (parser.parseInteger(result&: value) ||
1272 parser.parseRegion(region, /*arguments=*/{}))
1273 return failure();
1274 caseValues.push_back(Elt: value);
1275 }
1276 cases = parser.getBuilder().getDenseI64ArrayAttr(values: caseValues);
1277 return success();
1278}
1279
1280/// Print the case regions and values.
1281static void printSwitchCases(OpAsmPrinter &p, Operation *op,
1282 DenseI64ArrayAttr cases, RegionRange caseRegions) {
1283 for (auto [value, region] : llvm::zip(t: cases.asArrayRef(), u&: caseRegions)) {
1284 p.printNewline();
1285 p << "case " << value << ' ';
1286 p.printRegion(blocks&: *region, /*printEntryBlockArgs=*/false);
1287 }
1288}
1289
1290static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1291 const Twine &name) {
1292 auto yield = dyn_cast<emitc::YieldOp>(Val&: region.front().back());
1293 if (!yield)
1294 return op.emitOpError(message: "expected region to end with emitc.yield, but got ")
1295 << region.front().back().getName();
1296
1297 if (yield.getNumOperands() != 0) {
1298 return (op.emitOpError(message: "expected each region to return ")
1299 << "0 values, but " << name << " returns "
1300 << yield.getNumOperands())
1301 .attachNote(noteLoc: yield.getLoc())
1302 << "see yield operation here";
1303 }
1304
1305 return success();
1306}
1307
1308LogicalResult emitc::SwitchOp::verify() {
1309 if (!isIntegerIndexOrOpaqueType(type: getArg().getType()))
1310 return emitOpError(message: "unsupported type ") << getArg().getType();
1311
1312 if (getCases().size() != getCaseRegions().size()) {
1313 return emitOpError(message: "has ")
1314 << getCaseRegions().size() << " case regions but "
1315 << getCases().size() << " case values";
1316 }
1317
1318 DenseSet<int64_t> valueSet;
1319 for (int64_t value : getCases())
1320 if (!valueSet.insert(V: value).second)
1321 return emitOpError(message: "has duplicate case value: ") << value;
1322
1323 if (failed(Result: verifyRegion(op: *this, region&: getDefaultRegion(), name: "default region")))
1324 return failure();
1325
1326 for (auto [idx, caseRegion] : llvm::enumerate(First: getCaseRegions()))
1327 if (failed(Result: verifyRegion(op: *this, region&: caseRegion, name: "case region #" + Twine(idx))))
1328 return failure();
1329
1330 return success();
1331}
1332
1333unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1334
1335Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1336
1337Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1338 assert(idx < getNumCases() && "case index out-of-bounds");
1339 return getCaseRegions()[idx].front();
1340}
1341
1342void SwitchOp::getSuccessorRegions(
1343 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1344 llvm::append_range(C&: successors, R: getRegions());
1345}
1346
1347void SwitchOp::getEntrySuccessorRegions(
1348 ArrayRef<Attribute> operands,
1349 SmallVectorImpl<RegionSuccessor> &successors) {
1350 FoldAdaptor adaptor(operands, *this);
1351
1352 // If a constant was not provided, all regions are possible successors.
1353 auto arg = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getArg());
1354 if (!arg) {
1355 llvm::append_range(C&: successors, R: getRegions());
1356 return;
1357 }
1358
1359 // Otherwise, try to find a case with a matching value. If not, the
1360 // default region is the only successor.
1361 for (auto [caseValue, caseRegion] : llvm::zip(t: getCases(), u: getCaseRegions())) {
1362 if (caseValue == arg.getInt()) {
1363 successors.emplace_back(Args: &caseRegion);
1364 return;
1365 }
1366 }
1367 successors.emplace_back(Args: &getDefaultRegion());
1368}
1369
1370void SwitchOp::getRegionInvocationBounds(
1371 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1372 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(Val: operands.front());
1373 if (!operandValue) {
1374 // All regions are invoked at most once.
1375 bounds.append(NumInputs: getNumRegions(), Elt: InvocationBounds(/*lb=*/0, /*ub=*/1));
1376 return;
1377 }
1378
1379 unsigned liveIndex = getNumRegions() - 1;
1380 const auto *iteratorToInt = llvm::find(Range: getCases(), Val: operandValue.getInt());
1381
1382 liveIndex = iteratorToInt != getCases().end()
1383 ? std::distance(first: getCases().begin(), last: iteratorToInt)
1384 : liveIndex;
1385
1386 for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1387 ++regIndex)
1388 bounds.emplace_back(/*lb=*/Args: 0, /*ub=*/Args: regIndex == liveIndex);
1389}
1390
1391//===----------------------------------------------------------------------===//
1392// FileOp
1393//===----------------------------------------------------------------------===//
1394void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
1395 state.addRegion()->emplaceBlock();
1396 state.attributes.push_back(
1397 newAttribute: builder.getNamedAttr(name: "id", val: builder.getStringAttr(bytes: id)));
1398}
1399
1400//===----------------------------------------------------------------------===//
1401// FieldOp
1402//===----------------------------------------------------------------------===//
1403LogicalResult FieldOp::verify() {
1404 if (!isSupportedEmitCType(type: getType()))
1405 return emitOpError(message: "expected valid emitc type");
1406
1407 Operation *parentOp = getOperation()->getParentOp();
1408 if (!parentOp || !isa<emitc::ClassOp>(Val: parentOp))
1409 return emitOpError(message: "field must be nested within an emitc.class operation");
1410
1411 StringAttr symName = getSymNameAttr();
1412 if (!symName || symName.getValue().empty())
1413 return emitOpError(message: "field must have a non-empty symbol name");
1414
1415 if (!getAttrs())
1416 return success();
1417
1418 return success();
1419}
1420
1421//===----------------------------------------------------------------------===//
1422// GetFieldOp
1423//===----------------------------------------------------------------------===//
1424LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1425 mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1426 FieldOp fieldOp =
1427 symbolTable.lookupNearestSymbolFrom<FieldOp>(from: *this, symbol: fieldNameAttr);
1428 if (!fieldOp)
1429 return emitOpError(message: "field '")
1430 << fieldNameAttr << "' not found in the class";
1431
1432 Type getFieldResultType = getResult().getType();
1433 Type fieldType = fieldOp.getType();
1434
1435 if (fieldType != getFieldResultType)
1436 return emitOpError(message: "result type ")
1437 << getFieldResultType << " does not match field '" << fieldNameAttr
1438 << "' type " << fieldType;
1439
1440 return success();
1441}
1442
1443//===----------------------------------------------------------------------===//
1444// TableGen'd op method definitions
1445//===----------------------------------------------------------------------===//
1446
1447#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
1448
1449#define GET_OP_CLASSES
1450#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
1451

source code of mlir/lib/Dialect/EmitC/IR/EmitC.cpp