1//===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/EmitC/IR/EmitC.h"
10#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Types.h"
17#include "mlir/Interfaces/FunctionImplementation.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Casting.h"
22#include "llvm/Support/FormatVariadic.h"
23
24using namespace mlir;
25using namespace mlir::emitc;
26
27#include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
28
29//===----------------------------------------------------------------------===//
30// EmitCDialect
31//===----------------------------------------------------------------------===//
32
/// Registers the dialect's operations, types and attributes.
///
/// Each list of template arguments is pulled in from the corresponding
/// `.cpp.inc` file; the `GET_*_LIST` macro defined right before each include
/// selects which list that include expands to.
void EmitCDialect::initialize() {
  // Operations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
      >();
  // Types.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
      >();
  // Attributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
      >();
}
47
48/// Materialize a single constant operation from a given attribute value with
49/// the desired resultant type.
50Operation *EmitCDialect::materializeConstant(OpBuilder &builder,
51 Attribute value, Type type,
52 Location loc) {
53 return builder.create<emitc::ConstantOp>(loc, type, value);
54}
55
56/// Default callback for builders of ops carrying a region. Inserts a yield
57/// without arguments.
58void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
59 builder.create<emitc::YieldOp>(loc);
60}
61
62bool mlir::emitc::isSupportedEmitCType(Type type) {
63 if (llvm::isa<emitc::OpaqueType>(type))
64 return true;
65 if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
66 return isSupportedEmitCType(ptrType.getPointee());
67 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
68 auto elemType = arrayType.getElementType();
69 return !llvm::isa<emitc::ArrayType>(elemType) &&
70 isSupportedEmitCType(elemType);
71 }
72 if (type.isIndex() || emitc::isPointerWideType(type))
73 return true;
74 if (llvm::isa<IntegerType>(Val: type))
75 return isSupportedIntegerType(type);
76 if (llvm::isa<FloatType>(Val: type))
77 return isSupportedFloatType(type);
78 if (auto tensorType = llvm::dyn_cast<TensorType>(Val&: type)) {
79 if (!tensorType.hasStaticShape()) {
80 return false;
81 }
82 auto elemType = tensorType.getElementType();
83 if (llvm::isa<emitc::ArrayType>(elemType)) {
84 return false;
85 }
86 return isSupportedEmitCType(type: elemType);
87 }
88 if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
89 return llvm::all_of(tupleType.getTypes(), [](Type type) {
90 return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
91 });
92 }
93 return false;
94}
95
96bool mlir::emitc::isSupportedIntegerType(Type type) {
97 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
98 switch (intType.getWidth()) {
99 case 1:
100 case 8:
101 case 16:
102 case 32:
103 case 64:
104 return true;
105 default:
106 return false;
107 }
108 }
109 return false;
110}
111
112bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
113 return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
114 isSupportedIntegerType(type) || isPointerWideType(type);
115}
116
117bool mlir::emitc::isSupportedFloatType(Type type) {
118 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
119 switch (floatType.getWidth()) {
120 case 16: {
121 if (llvm::isa<Float16Type, BFloat16Type>(type))
122 return true;
123 return false;
124 }
125 case 32:
126 case 64:
127 return true;
128 default:
129 return false;
130 }
131 }
132 return false;
133}
134
135bool mlir::emitc::isPointerWideType(Type type) {
136 return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
137 type);
138}
139
140/// Check that the type of the initial value is compatible with the operations
141/// result type.
142static LogicalResult verifyInitializationAttribute(Operation *op,
143 Attribute value) {
144 assert(op->getNumResults() == 1 && "operation must have 1 result");
145
146 if (llvm::isa<emitc::OpaqueAttr>(value))
147 return success();
148
149 if (llvm::isa<StringAttr>(Val: value))
150 return op->emitOpError()
151 << "string attributes are not supported, use #emitc.opaque instead";
152
153 Type resultType = op->getResult(idx: 0).getType();
154 if (auto lType = dyn_cast<LValueType>(resultType))
155 resultType = lType.getValueType();
156 Type attrType = cast<TypedAttr>(value).getType();
157
158 if (isPointerWideType(type: resultType) && attrType.isIndex())
159 return success();
160
161 if (resultType != attrType)
162 return op->emitOpError()
163 << "requires attribute to either be an #emitc.opaque attribute or "
164 "it's type ("
165 << attrType << ") to match the op's result type (" << resultType
166 << ")";
167
168 return success();
169}
170
171/// Parse a format string and return a list of its parts.
172/// A part is either a StringRef that has to be printed as-is, or
173/// a Placeholder which requires printing the next operand of the VerbatimOp.
174/// In the format string, all `{}` are replaced by Placeholders, except if the
175/// `{` is escaped by `{{` - then it doesn't start a placeholder.
176template <class ArgType>
177FailureOr<SmallVector<ReplacementItem>>
178parseFormatString(StringRef toParse, ArgType fmtArgs,
179 std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
180 emitError = {}) {
181 SmallVector<ReplacementItem> items;
182
183 // If there are not operands, the format string is not interpreted.
184 if (fmtArgs.empty()) {
185 items.push_back(Elt: toParse);
186 return items;
187 }
188
189 while (!toParse.empty()) {
190 size_t idx = toParse.find(C: '{');
191 if (idx == StringRef::npos) {
192 // No '{'
193 items.push_back(Elt: toParse);
194 break;
195 }
196 if (idx > 0) {
197 // Take all chars excluding the '{'.
198 items.push_back(Elt: toParse.take_front(N: idx));
199 toParse = toParse.drop_front(N: idx);
200 continue;
201 }
202 if (toParse.size() < 2) {
203 return (*emitError)()
204 << "expected '}' after unescaped '{' at end of string";
205 }
206 // toParse contains at least two characters and starts with `{`.
207 char nextChar = toParse[1];
208 if (nextChar == '{') {
209 // Double '{{' -> '{' (escaping).
210 items.push_back(Elt: toParse.take_front(N: 1));
211 toParse = toParse.drop_front(N: 2);
212 continue;
213 }
214 if (nextChar == '}') {
215 items.push_back(Elt: Placeholder{});
216 toParse = toParse.drop_front(N: 2);
217 continue;
218 }
219
220 if (emitError.has_value()) {
221 return (*emitError)() << "expected '}' after unescaped '{'";
222 }
223 return failure();
224 }
225 return items;
226}
227
228//===----------------------------------------------------------------------===//
229// AddOp
230//===----------------------------------------------------------------------===//
231
232LogicalResult AddOp::verify() {
233 Type lhsType = getLhs().getType();
234 Type rhsType = getRhs().getType();
235
236 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
237 return emitOpError("requires that at most one operand is a pointer");
238
239 if ((isa<emitc::PointerType>(lhsType) &&
240 !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
241 (isa<emitc::PointerType>(rhsType) &&
242 !isa<IntegerType, emitc::OpaqueType>(lhsType)))
243 return emitOpError("requires that one operand is an integer or of opaque "
244 "type if the other is a pointer");
245
246 return success();
247}
248
249//===----------------------------------------------------------------------===//
250// ApplyOp
251//===----------------------------------------------------------------------===//
252
253LogicalResult ApplyOp::verify() {
254 StringRef applicableOperatorStr = getApplicableOperator();
255
256 // Applicable operator must not be empty.
257 if (applicableOperatorStr.empty())
258 return emitOpError("applicable operator must not be empty");
259
260 // Only `*` and `&` are supported.
261 if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
262 return emitOpError("applicable operator is illegal");
263
264 Type operandType = getOperand().getType();
265 Type resultType = getResult().getType();
266 if (applicableOperatorStr == "&") {
267 if (!llvm::isa<emitc::LValueType>(operandType))
268 return emitOpError("operand type must be an lvalue when applying `&`");
269 if (!llvm::isa<emitc::PointerType>(resultType))
270 return emitOpError("result type must be a pointer when applying `&`");
271 } else {
272 if (!llvm::isa<emitc::PointerType>(operandType))
273 return emitOpError("operand type must be a pointer when applying `*`");
274 }
275
276 return success();
277}
278
279//===----------------------------------------------------------------------===//
280// AssignOp
281//===----------------------------------------------------------------------===//
282
283/// The assign op requires that the assigned value's type matches the
284/// assigned-to variable type.
285LogicalResult emitc::AssignOp::verify() {
286 TypedValue<emitc::LValueType> variable = getVar();
287
288 if (!variable.getDefiningOp())
289 return emitOpError() << "cannot assign to block argument";
290
291 Type valueType = getValue().getType();
292 Type variableType = variable.getType().getValueType();
293 if (variableType != valueType)
294 return emitOpError() << "requires value's type (" << valueType
295 << ") to match variable's type (" << variableType
296 << ")\n variable: " << variable
297 << "\n value: " << getValue() << "\n";
298 return success();
299}
300
301//===----------------------------------------------------------------------===//
302// CastOp
303//===----------------------------------------------------------------------===//
304
305bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
306 Type input = inputs.front(), output = outputs.front();
307
308 if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
309 if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
310 return (arrayType.getElementType() == pointerType.getPointee()) &&
311 arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
312 }
313 return false;
314 }
315
316 return (
317 (emitc::isIntegerIndexOrOpaqueType(input) ||
318 emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
319 (emitc::isIntegerIndexOrOpaqueType(output) ||
320 emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
321}
322
323//===----------------------------------------------------------------------===//
324// CallOpaqueOp
325//===----------------------------------------------------------------------===//
326
327LogicalResult emitc::CallOpaqueOp::verify() {
328 // Callee must not be empty.
329 if (getCallee().empty())
330 return emitOpError("callee must not be empty");
331
332 if (std::optional<ArrayAttr> argsAttr = getArgs()) {
333 for (Attribute arg : *argsAttr) {
334 auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
335 if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
336 int64_t index = intAttr.getInt();
337 // Args with elements of type index must be in range
338 // [0..operands.size).
339 if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
340 return emitOpError("index argument is out of range");
341
342 // Args with elements of type ArrayAttr must have a type.
343 } else if (llvm::isa<ArrayAttr>(
344 arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
345 // FIXME: Array attributes never have types
346 return emitOpError("array argument has no type");
347 }
348 }
349 }
350
351 if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
352 for (Attribute tArg : *templateArgsAttr) {
353 if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
354 return emitOpError("template argument has invalid type");
355 }
356 }
357
358 if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
359 return emitOpError() << "cannot return array type";
360 }
361
362 return success();
363}
364
365//===----------------------------------------------------------------------===//
366// ConstantOp
367//===----------------------------------------------------------------------===//
368
369LogicalResult emitc::ConstantOp::verify() {
370 Attribute value = getValueAttr();
371 if (failed(verifyInitializationAttribute(getOperation(), value)))
372 return failure();
373 if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
374 if (opaqueValue.getValue().empty())
375 return emitOpError() << "value must not be empty";
376 }
377 return success();
378}
379
380OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
381
382//===----------------------------------------------------------------------===//
383// ExpressionOp
384//===----------------------------------------------------------------------===//
385
386Operation *ExpressionOp::getRootOp() {
387 auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
388 Value yieldedValue = yieldOp.getResult();
389 Operation *rootOp = yieldedValue.getDefiningOp();
390 assert(rootOp && "Yielded value not defined within expression");
391 return rootOp;
392}
393
394LogicalResult ExpressionOp::verify() {
395 Type resultType = getResult().getType();
396 Region &region = getRegion();
397
398 Block &body = region.front();
399
400 if (!body.mightHaveTerminator())
401 return emitOpError("must yield a value at termination");
402
403 auto yield = cast<YieldOp>(body.getTerminator());
404 Value yieldResult = yield.getResult();
405
406 if (!yieldResult)
407 return emitOpError("must yield a value at termination");
408
409 Type yieldType = yieldResult.getType();
410
411 if (resultType != yieldType)
412 return emitOpError("requires yielded type to match return type");
413
414 for (Operation &op : region.front().without_terminator()) {
415 if (!op.hasTrait<OpTrait::emitc::CExpression>())
416 return emitOpError("contains an unsupported operation");
417 if (op.getNumResults() != 1)
418 return emitOpError("requires exactly one result for each operation");
419 if (!op.getResult(0).hasOneUse())
420 return emitOpError("requires exactly one use for each operation");
421 }
422
423 return success();
424}
425
426//===----------------------------------------------------------------------===//
427// ForOp
428//===----------------------------------------------------------------------===//
429
430void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
431 Value ub, Value step, BodyBuilderFn bodyBuilder) {
432 OpBuilder::InsertionGuard g(builder);
433 result.addOperands({lb, ub, step});
434 Type t = lb.getType();
435 Region *bodyRegion = result.addRegion();
436 Block *bodyBlock = builder.createBlock(bodyRegion);
437 bodyBlock->addArgument(t, result.location);
438
439 // Create the default terminator if the builder is not provided.
440 if (!bodyBuilder) {
441 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
442 } else {
443 OpBuilder::InsertionGuard guard(builder);
444 builder.setInsertionPointToStart(bodyBlock);
445 bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
446 }
447}
448
449void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
450
451ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
452 Builder &builder = parser.getBuilder();
453 Type type;
454
455 OpAsmParser::Argument inductionVariable;
456 OpAsmParser::UnresolvedOperand lb, ub, step;
457
458 // Parse the induction variable followed by '='.
459 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
460 // Parse loop bounds.
461 parser.parseOperand(lb) || parser.parseKeyword("to") ||
462 parser.parseOperand(ub) || parser.parseKeyword("step") ||
463 parser.parseOperand(step))
464 return failure();
465
466 // Parse the optional initial iteration arguments.
467 SmallVector<OpAsmParser::Argument, 4> regionArgs;
468 regionArgs.push_back(inductionVariable);
469
470 // Parse optional type, else assume Index.
471 if (parser.parseOptionalColon())
472 type = builder.getIndexType();
473 else if (parser.parseType(type))
474 return failure();
475
476 // Resolve input operands.
477 regionArgs.front().type = type;
478 if (parser.resolveOperand(lb, type, result.operands) ||
479 parser.resolveOperand(ub, type, result.operands) ||
480 parser.resolveOperand(step, type, result.operands))
481 return failure();
482
483 // Parse the body region.
484 Region *body = result.addRegion();
485 if (parser.parseRegion(*body, regionArgs))
486 return failure();
487
488 ForOp::ensureTerminator(*body, builder, result.location);
489
490 // Parse the optional attribute list.
491 if (parser.parseOptionalAttrDict(result.attributes))
492 return failure();
493
494 return success();
495}
496
497void ForOp::print(OpAsmPrinter &p) {
498 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
499 << getUpperBound() << " step " << getStep();
500
501 p << ' ';
502 if (Type t = getInductionVar().getType(); !t.isIndex())
503 p << " : " << t << ' ';
504 p.printRegion(getRegion(),
505 /*printEntryBlockArgs=*/false,
506 /*printBlockTerminators=*/false);
507 p.printOptionalAttrDict((*this)->getAttrs());
508}
509
510LogicalResult ForOp::verifyRegions() {
511 // Check that the body defines as single block argument for the induction
512 // variable.
513 if (getInductionVar().getType() != getLowerBound().getType())
514 return emitOpError(
515 "expected induction variable to be same type as bounds and step");
516
517 return success();
518}
519
520//===----------------------------------------------------------------------===//
521// CallOp
522//===----------------------------------------------------------------------===//
523
524LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
525 // Check that the callee attribute was specified.
526 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
527 if (!fnAttr)
528 return emitOpError("requires a 'callee' symbol reference attribute");
529 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
530 if (!fn)
531 return emitOpError() << "'" << fnAttr.getValue()
532 << "' does not reference a valid function";
533
534 // Verify that the operand and result types match the callee.
535 auto fnType = fn.getFunctionType();
536 if (fnType.getNumInputs() != getNumOperands())
537 return emitOpError("incorrect number of operands for callee");
538
539 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
540 if (getOperand(i).getType() != fnType.getInput(i))
541 return emitOpError("operand type mismatch: expected operand type ")
542 << fnType.getInput(i) << ", but provided "
543 << getOperand(i).getType() << " for operand number " << i;
544
545 if (fnType.getNumResults() != getNumResults())
546 return emitOpError("incorrect number of results for callee");
547
548 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
549 if (getResult(i).getType() != fnType.getResult(i)) {
550 auto diag = emitOpError("result type mismatch at index ") << i;
551 diag.attachNote() << " op result types: " << getResultTypes();
552 diag.attachNote() << "function result types: " << fnType.getResults();
553 return diag;
554 }
555
556 return success();
557}
558
559FunctionType CallOp::getCalleeType() {
560 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
561}
562
563//===----------------------------------------------------------------------===//
564// DeclareFuncOp
565//===----------------------------------------------------------------------===//
566
567LogicalResult
568DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
569 // Check that the sym_name attribute was specified.
570 auto fnAttr = getSymNameAttr();
571 if (!fnAttr)
572 return emitOpError("requires a 'sym_name' symbol reference attribute");
573 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
574 if (!fn)
575 return emitOpError() << "'" << fnAttr.getValue()
576 << "' does not reference a valid function";
577
578 return success();
579}
580
581//===----------------------------------------------------------------------===//
582// FuncOp
583//===----------------------------------------------------------------------===//
584
585void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
586 FunctionType type, ArrayRef<NamedAttribute> attrs,
587 ArrayRef<DictionaryAttr> argAttrs) {
588 state.addAttribute(SymbolTable::getSymbolAttrName(),
589 builder.getStringAttr(name));
590 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
591 state.attributes.append(attrs.begin(), attrs.end());
592 state.addRegion();
593
594 if (argAttrs.empty())
595 return;
596 assert(type.getNumInputs() == argAttrs.size());
597 call_interface_impl::addArgAndResultAttrs(
598 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
599 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
600}
601
602ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
603 auto buildFuncType =
604 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
605 function_interface_impl::VariadicFlag,
606 std::string &) { return builder.getFunctionType(argTypes, results); };
607
608 return function_interface_impl::parseFunctionOp(
609 parser, result, /*allowVariadic=*/false,
610 getFunctionTypeAttrName(result.name), buildFuncType,
611 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
612}
613
614void FuncOp::print(OpAsmPrinter &p) {
615 function_interface_impl::printFunctionOp(
616 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
617 getArgAttrsAttrName(), getResAttrsAttrName());
618}
619
620LogicalResult FuncOp::verify() {
621 if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
622 return emitOpError("cannot have lvalue type as argument");
623 }
624
625 if (getNumResults() > 1)
626 return emitOpError("requires zero or exactly one result, but has ")
627 << getNumResults();
628
629 if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
630 return emitOpError("cannot return array type");
631
632 return success();
633}
634
635//===----------------------------------------------------------------------===//
636// ReturnOp
637//===----------------------------------------------------------------------===//
638
639LogicalResult ReturnOp::verify() {
640 auto function = cast<FuncOp>((*this)->getParentOp());
641
642 // The operand number and types must match the function signature.
643 if (getNumOperands() != function.getNumResults())
644 return emitOpError("has ")
645 << getNumOperands() << " operands, but enclosing function (@"
646 << function.getName() << ") returns " << function.getNumResults();
647
648 if (function.getNumResults() == 1)
649 if (getOperand().getType() != function.getResultTypes()[0])
650 return emitError() << "type of the return operand ("
651 << getOperand().getType()
652 << ") doesn't match function result type ("
653 << function.getResultTypes()[0] << ")"
654 << " in function @" << function.getName();
655 return success();
656}
657
658//===----------------------------------------------------------------------===//
659// IfOp
660//===----------------------------------------------------------------------===//
661
662void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
663 bool addThenBlock, bool addElseBlock) {
664 assert((!addElseBlock || addThenBlock) &&
665 "must not create else block w/o then block");
666 result.addOperands(cond);
667
668 // Add regions and blocks.
669 OpBuilder::InsertionGuard guard(builder);
670 Region *thenRegion = result.addRegion();
671 if (addThenBlock)
672 builder.createBlock(thenRegion);
673 Region *elseRegion = result.addRegion();
674 if (addElseBlock)
675 builder.createBlock(elseRegion);
676}
677
678void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
679 bool withElseRegion) {
680 result.addOperands(cond);
681
682 // Build then region.
683 OpBuilder::InsertionGuard guard(builder);
684 Region *thenRegion = result.addRegion();
685 builder.createBlock(thenRegion);
686
687 // Build else region.
688 Region *elseRegion = result.addRegion();
689 if (withElseRegion) {
690 builder.createBlock(elseRegion);
691 }
692}
693
694void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
695 function_ref<void(OpBuilder &, Location)> thenBuilder,
696 function_ref<void(OpBuilder &, Location)> elseBuilder) {
697 assert(thenBuilder && "the builder callback for 'then' must be present");
698 result.addOperands(cond);
699
700 // Build then region.
701 OpBuilder::InsertionGuard guard(builder);
702 Region *thenRegion = result.addRegion();
703 builder.createBlock(thenRegion);
704 thenBuilder(builder, result.location);
705
706 // Build else region.
707 Region *elseRegion = result.addRegion();
708 if (elseBuilder) {
709 builder.createBlock(elseRegion);
710 elseBuilder(builder, result.location);
711 }
712}
713
714ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
715 // Create the regions for 'then'.
716 result.regions.reserve(2);
717 Region *thenRegion = result.addRegion();
718 Region *elseRegion = result.addRegion();
719
720 Builder &builder = parser.getBuilder();
721 OpAsmParser::UnresolvedOperand cond;
722 Type i1Type = builder.getIntegerType(1);
723 if (parser.parseOperand(cond) ||
724 parser.resolveOperand(cond, i1Type, result.operands))
725 return failure();
726 // Parse the 'then' region.
727 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
728 return failure();
729 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
730
731 // If we find an 'else' keyword then parse the 'else' region.
732 if (!parser.parseOptionalKeyword("else")) {
733 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
734 return failure();
735 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
736 }
737
738 // Parse the optional attribute list.
739 if (parser.parseOptionalAttrDict(result.attributes))
740 return failure();
741 return success();
742}
743
744void IfOp::print(OpAsmPrinter &p) {
745 bool printBlockTerminators = false;
746
747 p << " " << getCondition();
748 p << ' ';
749 p.printRegion(getThenRegion(),
750 /*printEntryBlockArgs=*/false,
751 /*printBlockTerminators=*/printBlockTerminators);
752
753 // Print the 'else' regions if it exists and has a block.
754 Region &elseRegion = getElseRegion();
755 if (!elseRegion.empty()) {
756 p << " else ";
757 p.printRegion(elseRegion,
758 /*printEntryBlockArgs=*/false,
759 /*printBlockTerminators=*/printBlockTerminators);
760 }
761
762 p.printOptionalAttrDict((*this)->getAttrs());
763}
764
765/// Given the region at `index`, or the parent operation if `index` is None,
766/// return the successor regions. These are the regions that may be selected
767/// during the flow of control. `operands` is a set of optional attributes
768/// that correspond to a constant value for each operand, or null if that
769/// operand is not a constant.
770void IfOp::getSuccessorRegions(RegionBranchPoint point,
771 SmallVectorImpl<RegionSuccessor> &regions) {
772 // The `then` and the `else` region branch back to the parent operation.
773 if (!point.isParent()) {
774 regions.push_back(RegionSuccessor());
775 return;
776 }
777
778 regions.push_back(RegionSuccessor(&getThenRegion()));
779
780 // Don't consider the else region if it is empty.
781 Region *elseRegion = &this->getElseRegion();
782 if (elseRegion->empty())
783 regions.push_back(RegionSuccessor());
784 else
785 regions.push_back(RegionSuccessor(elseRegion));
786}
787
788void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
789 SmallVectorImpl<RegionSuccessor> &regions) {
790 FoldAdaptor adaptor(operands, *this);
791 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
792 if (!boolAttr || boolAttr.getValue())
793 regions.emplace_back(&getThenRegion());
794
795 // If the else region is empty, execution continues after the parent op.
796 if (!boolAttr || !boolAttr.getValue()) {
797 if (!getElseRegion().empty())
798 regions.emplace_back(&getElseRegion());
799 else
800 regions.emplace_back();
801 }
802}
803
804void IfOp::getRegionInvocationBounds(
805 ArrayRef<Attribute> operands,
806 SmallVectorImpl<InvocationBounds> &invocationBounds) {
807 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
808 // If the condition is known, then one region is known to be executed once
809 // and the other zero times.
810 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
811 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
812 } else {
813 // Non-constant condition. Each region may be executed 0 or 1 times.
814 invocationBounds.assign(2, {0, 1});
815 }
816}
817
818//===----------------------------------------------------------------------===//
819// IncludeOp
820//===----------------------------------------------------------------------===//
821
822void IncludeOp::print(OpAsmPrinter &p) {
823 bool standardInclude = getIsStandardInclude();
824
825 p << " ";
826 if (standardInclude)
827 p << "<";
828 p << "\"" << getInclude() << "\"";
829 if (standardInclude)
830 p << ">";
831}
832
833ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
834 bool standardInclude = !parser.parseOptionalLess();
835
836 StringAttr include;
837 OptionalParseResult includeParseResult =
838 parser.parseOptionalAttribute(include, "include", result.attributes);
839 if (!includeParseResult.has_value())
840 return parser.emitError(parser.getNameLoc()) << "expected string attribute";
841
842 if (standardInclude && parser.parseOptionalGreater())
843 return parser.emitError(parser.getNameLoc())
844 << "expected trailing '>' for standard include";
845
846 if (standardInclude)
847 result.addAttribute("is_standard_include",
848 UnitAttr::get(parser.getContext()));
849
850 return success();
851}
852
853//===----------------------------------------------------------------------===//
854// LiteralOp
855//===----------------------------------------------------------------------===//
856
857/// The literal op requires a non-empty value.
858LogicalResult emitc::LiteralOp::verify() {
859 if (getValue().empty())
860 return emitOpError() << "value must not be empty";
861 return success();
862}
863//===----------------------------------------------------------------------===//
864// SubOp
865//===----------------------------------------------------------------------===//
866
867LogicalResult SubOp::verify() {
868 Type lhsType = getLhs().getType();
869 Type rhsType = getRhs().getType();
870 Type resultType = getResult().getType();
871
872 if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
873 return emitOpError("rhs can only be a pointer if lhs is a pointer");
874
875 if (isa<emitc::PointerType>(lhsType) &&
876 !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
877 return emitOpError("requires that rhs is an integer, pointer or of opaque "
878 "type if lhs is a pointer");
879
880 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
881 !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
882 return emitOpError("requires that the result is an integer, ptrdiff_t or "
883 "of opaque type if lhs and rhs are pointers");
884 return success();
885}
886
887//===----------------------------------------------------------------------===//
888// VariableOp
889//===----------------------------------------------------------------------===//
890
891LogicalResult emitc::VariableOp::verify() {
892 return verifyInitializationAttribute(getOperation(), getValueAttr());
893}
894
895//===----------------------------------------------------------------------===//
896// YieldOp
897//===----------------------------------------------------------------------===//
898
899LogicalResult emitc::YieldOp::verify() {
900 Value result = getResult();
901 Operation *containingOp = getOperation()->getParentOp();
902
903 if (result && containingOp->getNumResults() != 1)
904 return emitOpError() << "yields a value not returned by parent";
905
906 if (!result && containingOp->getNumResults() != 0)
907 return emitOpError() << "does not yield a value to be returned by parent";
908
909 return success();
910}
911
912//===----------------------------------------------------------------------===//
913// SubscriptOp
914//===----------------------------------------------------------------------===//
915
916LogicalResult emitc::SubscriptOp::verify() {
917 // Checks for array operand.
918 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
919 // Check number of indices.
920 if (getIndices().size() != (size_t)arrayType.getRank()) {
921 return emitOpError() << "on array operand requires number of indices ("
922 << getIndices().size()
923 << ") to match the rank of the array type ("
924 << arrayType.getRank() << ")";
925 }
926 // Check types of index operands.
927 for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
928 Type type = getIndices()[i].getType();
929 if (!isIntegerIndexOrOpaqueType(type)) {
930 return emitOpError() << "on array operand requires index operand " << i
931 << " to be integer-like, but got " << type;
932 }
933 }
934 // Check element type.
935 Type elementType = arrayType.getElementType();
936 Type resultType = getType().getValueType();
937 if (elementType != resultType) {
938 return emitOpError() << "on array operand requires element type ("
939 << elementType << ") and result type (" << resultType
940 << ") to match";
941 }
942 return success();
943 }
944
945 // Checks for pointer operand.
946 if (auto pointerType =
947 llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
948 // Check number of indices.
949 if (getIndices().size() != 1) {
950 return emitOpError()
951 << "on pointer operand requires one index operand, but got "
952 << getIndices().size();
953 }
954 // Check types of index operand.
955 Type type = getIndices()[0].getType();
956 if (!isIntegerIndexOrOpaqueType(type)) {
957 return emitOpError() << "on pointer operand requires index operand to be "
958 "integer-like, but got "
959 << type;
960 }
961 // Check pointee type.
962 Type pointeeType = pointerType.getPointee();
963 Type resultType = getType().getValueType();
964 if (pointeeType != resultType) {
965 return emitOpError() << "on pointer operand requires pointee type ("
966 << pointeeType << ") and result type (" << resultType
967 << ") to match";
968 }
969 return success();
970 }
971
972 // The operand has opaque type, so we can't assume anything about the number
973 // or types of index operands.
974 return success();
975}
976
977//===----------------------------------------------------------------------===//
978// VerbatimOp
979//===----------------------------------------------------------------------===//
980
981LogicalResult emitc::VerbatimOp::verify() {
982 auto errorCallback = [&]() -> InFlightDiagnostic {
983 return this->emitOpError();
984 };
985 FailureOr<SmallVector<ReplacementItem>> fmt =
986 ::parseFormatString(getValue(), getFmtArgs(), errorCallback);
987 if (failed(fmt))
988 return failure();
989
990 size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
991 return std::holds_alternative<Placeholder>(item);
992 });
993
994 if (numPlaceholders != getFmtArgs().size()) {
995 return emitOpError()
996 << "requires operands for each placeholder in the format string";
997 }
998 return success();
999}
1000
1001FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1002 // Error checking is done in verify.
1003 return ::parseFormatString(getValue(), getFmtArgs());
1004}
1005
1006//===----------------------------------------------------------------------===//
1007// EmitC Enums
1008//===----------------------------------------------------------------------===//
1009
1010#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1011
1012//===----------------------------------------------------------------------===//
1013// EmitC Attributes
1014//===----------------------------------------------------------------------===//
1015
1016#define GET_ATTRDEF_CLASSES
1017#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1018
1019//===----------------------------------------------------------------------===//
1020// EmitC Types
1021//===----------------------------------------------------------------------===//
1022
1023#define GET_TYPEDEF_CLASSES
1024#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1025
1026//===----------------------------------------------------------------------===//
1027// ArrayType
1028//===----------------------------------------------------------------------===//
1029
1030Type emitc::ArrayType::parse(AsmParser &parser) {
1031 if (parser.parseLess())
1032 return Type();
1033
1034 SmallVector<int64_t, 4> dimensions;
1035 if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
1036 /*withTrailingX=*/true))
1037 return Type();
1038 // Parse the element type.
1039 auto typeLoc = parser.getCurrentLocation();
1040 Type elementType;
1041 if (parser.parseType(elementType))
1042 return Type();
1043
1044 // Check that array is formed from allowed types.
1045 if (!isValidElementType(elementType))
1046 return parser.emitError(typeLoc, "invalid array element type '")
1047 << elementType << "'",
1048 Type();
1049 if (parser.parseGreater())
1050 return Type();
1051 return parser.getChecked<ArrayType>(dimensions, elementType);
1052}
1053
1054void emitc::ArrayType::print(AsmPrinter &printer) const {
1055 printer << "<";
1056 for (int64_t dim : getShape()) {
1057 printer << dim << 'x';
1058 }
1059 printer.printType(getElementType());
1060 printer << ">";
1061}
1062
1063LogicalResult emitc::ArrayType::verify(
1064 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
1065 ::llvm::ArrayRef<int64_t> shape, Type elementType) {
1066 if (shape.empty())
1067 return emitError() << "shape must not be empty";
1068
1069 for (int64_t dim : shape) {
1070 if (dim < 0)
1071 return emitError() << "dimensions must have non-negative size";
1072 }
1073
1074 if (!elementType)
1075 return emitError() << "element type must not be none";
1076
1077 if (!isValidElementType(elementType))
1078 return emitError() << "invalid array element type";
1079
1080 return success();
1081}
1082
1083emitc::ArrayType
1084emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1085 Type elementType) const {
1086 if (!shape)
1087 return emitc::ArrayType::get(getShape(), elementType);
1088 return emitc::ArrayType::get(*shape, elementType);
1089}
1090
1091//===----------------------------------------------------------------------===//
1092// LValueType
1093//===----------------------------------------------------------------------===//
1094
1095LogicalResult mlir::emitc::LValueType::verify(
1096 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1097 mlir::Type value) {
1098 // Check that the wrapped type is valid. This especially forbids nested
1099 // lvalue types.
1100 if (!isSupportedEmitCType(value))
1101 return emitError()
1102 << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1103
1104 if (llvm::isa<emitc::ArrayType>(value))
1105 return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1106
1107 return success();
1108}
1109
1110//===----------------------------------------------------------------------===//
1111// OpaqueType
1112//===----------------------------------------------------------------------===//
1113
1114LogicalResult mlir::emitc::OpaqueType::verify(
1115 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1116 llvm::StringRef value) {
1117 if (value.empty()) {
1118 return emitError() << "expected non empty string in !emitc.opaque type";
1119 }
1120 if (value.back() == '*') {
1121 return emitError() << "pointer not allowed as outer type with "
1122 "!emitc.opaque, use !emitc.ptr instead";
1123 }
1124 return success();
1125}
1126
1127//===----------------------------------------------------------------------===//
1128// PointerType
1129//===----------------------------------------------------------------------===//
1130
1131LogicalResult mlir::emitc::PointerType::verify(
1132 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, Type value) {
1133 if (llvm::isa<emitc::LValueType>(value))
1134 return emitError() << "pointers to lvalues are not allowed";
1135
1136 return success();
1137}
1138
1139//===----------------------------------------------------------------------===//
1140// GlobalOp
1141//===----------------------------------------------------------------------===//
1142static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1143 TypeAttr type,
1144 Attribute initialValue) {
1145 p << type;
1146 if (initialValue) {
1147 p << " = ";
1148 p.printAttributeWithoutType(attr: initialValue);
1149 }
1150}
1151
1152static Type getInitializerTypeForGlobal(Type type) {
1153 if (auto array = llvm::dyn_cast<ArrayType>(type))
1154 return RankedTensorType::get(array.getShape(), array.getElementType());
1155 return type;
1156}
1157
1158static ParseResult
1159parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1160 Attribute &initialValue) {
1161 Type type;
1162 if (parser.parseType(result&: type))
1163 return failure();
1164
1165 typeAttr = TypeAttr::get(type);
1166
1167 if (parser.parseOptionalEqual())
1168 return success();
1169
1170 if (parser.parseAttribute(result&: initialValue, type: getInitializerTypeForGlobal(type)))
1171 return failure();
1172
1173 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1174 initialValue))
1175 return parser.emitError(loc: parser.getNameLoc())
1176 << "initial value should be a integer, float, elements or opaque "
1177 "attribute";
1178 return success();
1179}
1180
1181LogicalResult GlobalOp::verify() {
1182 if (!isSupportedEmitCType(getType())) {
1183 return emitOpError("expected valid emitc type");
1184 }
1185 if (getInitialValue().has_value()) {
1186 Attribute initValue = getInitialValue().value();
1187 // Check that the type of the initial value is compatible with the type of
1188 // the global variable.
1189 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1190 auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1191 if (!arrayType)
1192 return emitOpError("expected array type, but got ") << getType();
1193
1194 Type initType = elementsAttr.getType();
1195 Type tensorType = getInitializerTypeForGlobal(getType());
1196 if (initType != tensorType) {
1197 return emitOpError("initial value expected to be of type ")
1198 << getType() << ", but was of type " << initType;
1199 }
1200 } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1201 if (intAttr.getType() != getType()) {
1202 return emitOpError("initial value expected to be of type ")
1203 << getType() << ", but was of type " << intAttr.getType();
1204 }
1205 } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1206 if (floatAttr.getType() != getType()) {
1207 return emitOpError("initial value expected to be of type ")
1208 << getType() << ", but was of type " << floatAttr.getType();
1209 }
1210 } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1211 return emitOpError("initial value should be a integer, float, elements "
1212 "or opaque attribute, but got ")
1213 << initValue;
1214 }
1215 }
1216 if (getStaticSpecifier() && getExternSpecifier()) {
1217 return emitOpError("cannot have both static and extern specifiers");
1218 }
1219 return success();
1220}
1221
1222//===----------------------------------------------------------------------===//
1223// GetGlobalOp
1224//===----------------------------------------------------------------------===//
1225
1226LogicalResult
1227GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1228 // Verify that the type matches the type of the global variable.
1229 auto global =
1230 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1231 if (!global)
1232 return emitOpError("'")
1233 << getName() << "' does not reference a valid emitc.global";
1234
1235 Type resultType = getResult().getType();
1236 Type globalType = global.getType();
1237
1238 // global has array type
1239 if (llvm::isa<ArrayType>(globalType)) {
1240 if (globalType != resultType)
1241 return emitOpError("on array type expects result type ")
1242 << resultType << " to match type " << globalType
1243 << " of the global @" << getName();
1244 return success();
1245 }
1246
1247 // global has non-array type
1248 auto lvalueType = dyn_cast<LValueType>(resultType);
1249 if (!lvalueType || lvalueType.getValueType() != globalType)
1250 return emitOpError("on non-array type expects result inner type ")
1251 << lvalueType.getValueType() << " to match type " << globalType
1252 << " of the global @" << getName();
1253 return success();
1254}
1255
1256//===----------------------------------------------------------------------===//
1257// SwitchOp
1258//===----------------------------------------------------------------------===//
1259
1260/// Parse the case regions and values.
1261static ParseResult
1262parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
1263 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1264 SmallVector<int64_t> caseValues;
1265 while (succeeded(Result: parser.parseOptionalKeyword(keyword: "case"))) {
1266 int64_t value;
1267 Region &region = *caseRegions.emplace_back(Args: std::make_unique<Region>());
1268 if (parser.parseInteger(result&: value) ||
1269 parser.parseRegion(region, /*arguments=*/{}))
1270 return failure();
1271 caseValues.push_back(Elt: value);
1272 }
1273 cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1274 return success();
1275}
1276
1277/// Print the case regions and values.
1278static void printSwitchCases(OpAsmPrinter &p, Operation *op,
1279 DenseI64ArrayAttr cases, RegionRange caseRegions) {
1280 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1281 p.printNewline();
1282 p << "case " << value << ' ';
1283 p.printRegion(*region, /*printEntryBlockArgs=*/false);
1284 }
1285}
1286
1287static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1288 const Twine &name) {
1289 auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1290 if (!yield)
1291 return op.emitOpError("expected region to end with emitc.yield, but got ")
1292 << region.front().back().getName();
1293
1294 if (yield.getNumOperands() != 0) {
1295 return (op.emitOpError("expected each region to return ")
1296 << "0 values, but " << name << " returns "
1297 << yield.getNumOperands())
1298 .attachNote(yield.getLoc())
1299 << "see yield operation here";
1300 }
1301
1302 return success();
1303}
1304
1305LogicalResult emitc::SwitchOp::verify() {
1306 if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1307 return emitOpError("unsupported type ") << getArg().getType();
1308
1309 if (getCases().size() != getCaseRegions().size()) {
1310 return emitOpError("has ")
1311 << getCaseRegions().size() << " case regions but "
1312 << getCases().size() << " case values";
1313 }
1314
1315 DenseSet<int64_t> valueSet;
1316 for (int64_t value : getCases())
1317 if (!valueSet.insert(value).second)
1318 return emitOpError("has duplicate case value: ") << value;
1319
1320 if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1321 return failure();
1322
1323 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1324 if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1325 return failure();
1326
1327 return success();
1328}
1329
1330unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1331
1332Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1333
1334Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1335 assert(idx < getNumCases() && "case index out-of-bounds");
1336 return getCaseRegions()[idx].front();
1337}
1338
1339void SwitchOp::getSuccessorRegions(
1340 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1341 llvm::append_range(successors, getRegions());
1342}
1343
1344void SwitchOp::getEntrySuccessorRegions(
1345 ArrayRef<Attribute> operands,
1346 SmallVectorImpl<RegionSuccessor> &successors) {
1347 FoldAdaptor adaptor(operands, *this);
1348
1349 // If a constant was not provided, all regions are possible successors.
1350 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1351 if (!arg) {
1352 llvm::append_range(successors, getRegions());
1353 return;
1354 }
1355
1356 // Otherwise, try to find a case with a matching value. If not, the
1357 // default region is the only successor.
1358 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1359 if (caseValue == arg.getInt()) {
1360 successors.emplace_back(&caseRegion);
1361 return;
1362 }
1363 }
1364 successors.emplace_back(&getDefaultRegion());
1365}
1366
1367void SwitchOp::getRegionInvocationBounds(
1368 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1369 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1370 if (!operandValue) {
1371 // All regions are invoked at most once.
1372 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1373 return;
1374 }
1375
1376 unsigned liveIndex = getNumRegions() - 1;
1377 const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1378
1379 liveIndex = iteratorToInt != getCases().end()
1380 ? std::distance(getCases().begin(), iteratorToInt)
1381 : liveIndex;
1382
1383 for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1384 ++regIndex)
1385 bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1386}
1387
1388//===----------------------------------------------------------------------===//
1389// FileOp
1390//===----------------------------------------------------------------------===//
1391void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
1392 state.addRegion()->emplaceBlock();
1393 state.attributes.push_back(
1394 builder.getNamedAttr("id", builder.getStringAttr(id)));
1395}
1396
1397//===----------------------------------------------------------------------===//
1398// TableGen'd op method definitions
1399//===----------------------------------------------------------------------===//
1400
1401#define GET_OP_CLASSES
1402#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
1403

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/EmitC/IR/EmitC.cpp