1//===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/EmitC/IR/EmitC.h"
10#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Types.h"
17#include "mlir/Interfaces/FunctionImplementation.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Casting.h"
22
23using namespace mlir;
24using namespace mlir::emitc;
25
26#include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
27
28//===----------------------------------------------------------------------===//
29// EmitCDialect
30//===----------------------------------------------------------------------===//
31
32void EmitCDialect::initialize() {
33 addOperations<
34#define GET_OP_LIST
35#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
36 >();
37 addTypes<
38#define GET_TYPEDEF_LIST
39#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
40 >();
41 addAttributes<
42#define GET_ATTRDEF_LIST
43#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
44 >();
45}
46
47/// Materialize a single constant operation from a given attribute value with
48/// the desired resultant type.
49Operation *EmitCDialect::materializeConstant(OpBuilder &builder,
50 Attribute value, Type type,
51 Location loc) {
52 return builder.create<emitc::ConstantOp>(loc, type, value);
53}
54
55/// Default callback for builders of ops carrying a region. Inserts a yield
56/// without arguments.
57void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
58 builder.create<emitc::YieldOp>(loc);
59}
60
61bool mlir::emitc::isSupportedEmitCType(Type type) {
62 if (llvm::isa<emitc::OpaqueType>(type))
63 return true;
64 if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
65 return isSupportedEmitCType(ptrType.getPointee());
66 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
67 auto elemType = arrayType.getElementType();
68 return !llvm::isa<emitc::ArrayType>(elemType) &&
69 isSupportedEmitCType(elemType);
70 }
71 if (type.isIndex())
72 return true;
73 if (llvm::isa<IntegerType>(Val: type))
74 return isSupportedIntegerType(type);
75 if (llvm::isa<FloatType>(Val: type))
76 return isSupportedFloatType(type);
77 if (auto tensorType = llvm::dyn_cast<TensorType>(Val&: type)) {
78 if (!tensorType.hasStaticShape()) {
79 return false;
80 }
81 auto elemType = tensorType.getElementType();
82 if (llvm::isa<emitc::ArrayType>(elemType)) {
83 return false;
84 }
85 return isSupportedEmitCType(type: elemType);
86 }
87 if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
88 return llvm::all_of(tupleType.getTypes(), [](Type type) {
89 return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
90 });
91 }
92 return false;
93}
94
95bool mlir::emitc::isSupportedIntegerType(Type type) {
96 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
97 switch (intType.getWidth()) {
98 case 1:
99 case 8:
100 case 16:
101 case 32:
102 case 64:
103 return true;
104 default:
105 return false;
106 }
107 }
108 return false;
109}
110
111bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
112 return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
113 isSupportedIntegerType(type);
114}
115
116bool mlir::emitc::isSupportedFloatType(Type type) {
117 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
118 switch (floatType.getWidth()) {
119 case 32:
120 case 64:
121 return true;
122 default:
123 return false;
124 }
125 }
126 return false;
127}
128
129/// Check that the type of the initial value is compatible with the operations
130/// result type.
131static LogicalResult verifyInitializationAttribute(Operation *op,
132 Attribute value) {
133 assert(op->getNumResults() == 1 && "operation must have 1 result");
134
135 if (llvm::isa<emitc::OpaqueAttr>(value))
136 return success();
137
138 if (llvm::isa<StringAttr>(Val: value))
139 return op->emitOpError()
140 << "string attributes are not supported, use #emitc.opaque instead";
141
142 Type resultType = op->getResult(idx: 0).getType();
143 Type attrType = cast<TypedAttr>(value).getType();
144
145 if (resultType != attrType)
146 return op->emitOpError()
147 << "requires attribute to either be an #emitc.opaque attribute or "
148 "it's type ("
149 << attrType << ") to match the op's result type (" << resultType
150 << ")";
151
152 return success();
153}
154
155//===----------------------------------------------------------------------===//
156// AddOp
157//===----------------------------------------------------------------------===//
158
159LogicalResult AddOp::verify() {
160 Type lhsType = getLhs().getType();
161 Type rhsType = getRhs().getType();
162
163 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
164 return emitOpError("requires that at most one operand is a pointer");
165
166 if ((isa<emitc::PointerType>(lhsType) &&
167 !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
168 (isa<emitc::PointerType>(rhsType) &&
169 !isa<IntegerType, emitc::OpaqueType>(lhsType)))
170 return emitOpError("requires that one operand is an integer or of opaque "
171 "type if the other is a pointer");
172
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// ApplyOp
178//===----------------------------------------------------------------------===//
179
180LogicalResult ApplyOp::verify() {
181 StringRef applicableOperatorStr = getApplicableOperator();
182
183 // Applicable operator must not be empty.
184 if (applicableOperatorStr.empty())
185 return emitOpError("applicable operator must not be empty");
186
187 // Only `*` and `&` are supported.
188 if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
189 return emitOpError("applicable operator is illegal");
190
191 Operation *op = getOperand().getDefiningOp();
192 if (op && dyn_cast<ConstantOp>(op))
193 return emitOpError("cannot apply to constant");
194
195 return success();
196}
197
198//===----------------------------------------------------------------------===//
199// AssignOp
200//===----------------------------------------------------------------------===//
201
202/// The assign op requires that the assigned value's type matches the
203/// assigned-to variable type.
204LogicalResult emitc::AssignOp::verify() {
205 Value variable = getVar();
206 Operation *variableDef = variable.getDefiningOp();
207 if (!variableDef ||
208 !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
209 return emitOpError() << "requires first operand (" << variable
210 << ") to be a Variable or subscript";
211
212 Value value = getValue();
213 if (variable.getType() != value.getType())
214 return emitOpError() << "requires value's type (" << value.getType()
215 << ") to match variable's type (" << variable.getType()
216 << ")";
217 if (isa<ArrayType>(variable.getType()))
218 return emitOpError() << "cannot assign to array type";
219 return success();
220}
221
222//===----------------------------------------------------------------------===//
223// CastOp
224//===----------------------------------------------------------------------===//
225
226bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
227 Type input = inputs.front(), output = outputs.front();
228
229 return ((llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
230 emitc::PointerType>(input)) &&
231 (llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
232 emitc::PointerType>(output)));
233}
234
235//===----------------------------------------------------------------------===//
// CallOpaqueOp
237//===----------------------------------------------------------------------===//
238
239LogicalResult emitc::CallOpaqueOp::verify() {
240 // Callee must not be empty.
241 if (getCallee().empty())
242 return emitOpError("callee must not be empty");
243
244 if (std::optional<ArrayAttr> argsAttr = getArgs()) {
245 for (Attribute arg : *argsAttr) {
246 auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
247 if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
248 int64_t index = intAttr.getInt();
249 // Args with elements of type index must be in range
250 // [0..operands.size).
251 if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
252 return emitOpError("index argument is out of range");
253
254 // Args with elements of type ArrayAttr must have a type.
255 } else if (llvm::isa<ArrayAttr>(
256 arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
257 // FIXME: Array attributes never have types
258 return emitOpError("array argument has no type");
259 }
260 }
261 }
262
263 if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
264 for (Attribute tArg : *templateArgsAttr) {
265 if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
266 return emitOpError("template argument has invalid type");
267 }
268 }
269
270 if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
271 return emitOpError() << "cannot return array type";
272 }
273
274 return success();
275}
276
277//===----------------------------------------------------------------------===//
278// ConstantOp
279//===----------------------------------------------------------------------===//
280
281LogicalResult emitc::ConstantOp::verify() {
282 Attribute value = getValueAttr();
283 if (failed(verifyInitializationAttribute(getOperation(), value)))
284 return failure();
285 if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
286 if (opaqueValue.getValue().empty())
287 return emitOpError() << "value must not be empty";
288 }
289 return success();
290}
291
292OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
293
294//===----------------------------------------------------------------------===//
295// ExpressionOp
296//===----------------------------------------------------------------------===//
297
298Operation *ExpressionOp::getRootOp() {
299 auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
300 Value yieldedValue = yieldOp.getResult();
301 Operation *rootOp = yieldedValue.getDefiningOp();
302 assert(rootOp && "Yielded value not defined within expression");
303 return rootOp;
304}
305
306LogicalResult ExpressionOp::verify() {
307 Type resultType = getResult().getType();
308 Region &region = getRegion();
309
310 Block &body = region.front();
311
312 if (!body.mightHaveTerminator())
313 return emitOpError("must yield a value at termination");
314
315 auto yield = cast<YieldOp>(body.getTerminator());
316 Value yieldResult = yield.getResult();
317
318 if (!yieldResult)
319 return emitOpError("must yield a value at termination");
320
321 Type yieldType = yieldResult.getType();
322
323 if (resultType != yieldType)
324 return emitOpError("requires yielded type to match return type");
325
326 for (Operation &op : region.front().without_terminator()) {
327 if (!op.hasTrait<OpTrait::emitc::CExpression>())
328 return emitOpError("contains an unsupported operation");
329 if (op.getNumResults() != 1)
330 return emitOpError("requires exactly one result for each operation");
331 if (!op.getResult(0).hasOneUse())
332 return emitOpError("requires exactly one use for each operation");
333 }
334
335 return success();
336}
337
338//===----------------------------------------------------------------------===//
339// ForOp
340//===----------------------------------------------------------------------===//
341
342void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
343 Value ub, Value step, BodyBuilderFn bodyBuilder) {
344 OpBuilder::InsertionGuard g(builder);
345 result.addOperands({lb, ub, step});
346 Type t = lb.getType();
347 Region *bodyRegion = result.addRegion();
348 Block *bodyBlock = builder.createBlock(bodyRegion);
349 bodyBlock->addArgument(t, result.location);
350
351 // Create the default terminator if the builder is not provided.
352 if (!bodyBuilder) {
353 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
354 } else {
355 OpBuilder::InsertionGuard guard(builder);
356 builder.setInsertionPointToStart(bodyBlock);
357 bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
358 }
359}
360
361void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
362
363ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
364 Builder &builder = parser.getBuilder();
365 Type type;
366
367 OpAsmParser::Argument inductionVariable;
368 OpAsmParser::UnresolvedOperand lb, ub, step;
369
370 // Parse the induction variable followed by '='.
371 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
372 // Parse loop bounds.
373 parser.parseOperand(lb) || parser.parseKeyword("to") ||
374 parser.parseOperand(ub) || parser.parseKeyword("step") ||
375 parser.parseOperand(step))
376 return failure();
377
378 // Parse the optional initial iteration arguments.
379 SmallVector<OpAsmParser::Argument, 4> regionArgs;
380 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
381 regionArgs.push_back(inductionVariable);
382
383 // Parse optional type, else assume Index.
384 if (parser.parseOptionalColon())
385 type = builder.getIndexType();
386 else if (parser.parseType(type))
387 return failure();
388
389 // Resolve input operands.
390 regionArgs.front().type = type;
391 if (parser.resolveOperand(lb, type, result.operands) ||
392 parser.resolveOperand(ub, type, result.operands) ||
393 parser.resolveOperand(step, type, result.operands))
394 return failure();
395
396 // Parse the body region.
397 Region *body = result.addRegion();
398 if (parser.parseRegion(*body, regionArgs))
399 return failure();
400
401 ForOp::ensureTerminator(*body, builder, result.location);
402
403 // Parse the optional attribute list.
404 if (parser.parseOptionalAttrDict(result.attributes))
405 return failure();
406
407 return success();
408}
409
410void ForOp::print(OpAsmPrinter &p) {
411 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
412 << getUpperBound() << " step " << getStep();
413
414 p << ' ';
415 if (Type t = getInductionVar().getType(); !t.isIndex())
416 p << " : " << t << ' ';
417 p.printRegion(getRegion(),
418 /*printEntryBlockArgs=*/false,
419 /*printBlockTerminators=*/false);
420 p.printOptionalAttrDict((*this)->getAttrs());
421}
422
423LogicalResult ForOp::verifyRegions() {
424 // Check that the body defines as single block argument for the induction
425 // variable.
426 if (getInductionVar().getType() != getLowerBound().getType())
427 return emitOpError(
428 "expected induction variable to be same type as bounds and step");
429
430 return success();
431}
432
433//===----------------------------------------------------------------------===//
434// CallOp
435//===----------------------------------------------------------------------===//
436
437LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
438 // Check that the callee attribute was specified.
439 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
440 if (!fnAttr)
441 return emitOpError("requires a 'callee' symbol reference attribute");
442 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
443 if (!fn)
444 return emitOpError() << "'" << fnAttr.getValue()
445 << "' does not reference a valid function";
446
447 // Verify that the operand and result types match the callee.
448 auto fnType = fn.getFunctionType();
449 if (fnType.getNumInputs() != getNumOperands())
450 return emitOpError("incorrect number of operands for callee");
451
452 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
453 if (getOperand(i).getType() != fnType.getInput(i))
454 return emitOpError("operand type mismatch: expected operand type ")
455 << fnType.getInput(i) << ", but provided "
456 << getOperand(i).getType() << " for operand number " << i;
457
458 if (fnType.getNumResults() != getNumResults())
459 return emitOpError("incorrect number of results for callee");
460
461 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
462 if (getResult(i).getType() != fnType.getResult(i)) {
463 auto diag = emitOpError("result type mismatch at index ") << i;
464 diag.attachNote() << " op result types: " << getResultTypes();
465 diag.attachNote() << "function result types: " << fnType.getResults();
466 return diag;
467 }
468
469 return success();
470}
471
472FunctionType CallOp::getCalleeType() {
473 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
474}
475
476//===----------------------------------------------------------------------===//
477// DeclareFuncOp
478//===----------------------------------------------------------------------===//
479
480LogicalResult
481DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
482 // Check that the sym_name attribute was specified.
483 auto fnAttr = getSymNameAttr();
484 if (!fnAttr)
485 return emitOpError("requires a 'sym_name' symbol reference attribute");
486 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
487 if (!fn)
488 return emitOpError() << "'" << fnAttr.getValue()
489 << "' does not reference a valid function";
490
491 return success();
492}
493
494//===----------------------------------------------------------------------===//
495// FuncOp
496//===----------------------------------------------------------------------===//
497
498void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
499 FunctionType type, ArrayRef<NamedAttribute> attrs,
500 ArrayRef<DictionaryAttr> argAttrs) {
501 state.addAttribute(SymbolTable::getSymbolAttrName(),
502 builder.getStringAttr(name));
503 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
504 state.attributes.append(attrs.begin(), attrs.end());
505 state.addRegion();
506
507 if (argAttrs.empty())
508 return;
509 assert(type.getNumInputs() == argAttrs.size());
510 function_interface_impl::addArgAndResultAttrs(
511 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
512 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
513}
514
515ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
516 auto buildFuncType =
517 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
518 function_interface_impl::VariadicFlag,
519 std::string &) { return builder.getFunctionType(argTypes, results); };
520
521 return function_interface_impl::parseFunctionOp(
522 parser, result, /*allowVariadic=*/false,
523 getFunctionTypeAttrName(result.name), buildFuncType,
524 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
525}
526
527void FuncOp::print(OpAsmPrinter &p) {
528 function_interface_impl::printFunctionOp(
529 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
530 getArgAttrsAttrName(), getResAttrsAttrName());
531}
532
533LogicalResult FuncOp::verify() {
534 if (getNumResults() > 1)
535 return emitOpError("requires zero or exactly one result, but has ")
536 << getNumResults();
537
538 if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
539 return emitOpError("cannot return array type");
540
541 return success();
542}
543
544//===----------------------------------------------------------------------===//
545// ReturnOp
546//===----------------------------------------------------------------------===//
547
548LogicalResult ReturnOp::verify() {
549 auto function = cast<FuncOp>((*this)->getParentOp());
550
551 // The operand number and types must match the function signature.
552 if (getNumOperands() != function.getNumResults())
553 return emitOpError("has ")
554 << getNumOperands() << " operands, but enclosing function (@"
555 << function.getName() << ") returns " << function.getNumResults();
556
557 if (function.getNumResults() == 1)
558 if (getOperand().getType() != function.getResultTypes()[0])
559 return emitError() << "type of the return operand ("
560 << getOperand().getType()
561 << ") doesn't match function result type ("
562 << function.getResultTypes()[0] << ")"
563 << " in function @" << function.getName();
564 return success();
565}
566
567//===----------------------------------------------------------------------===//
568// IfOp
569//===----------------------------------------------------------------------===//
570
571void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
572 bool addThenBlock, bool addElseBlock) {
573 assert((!addElseBlock || addThenBlock) &&
574 "must not create else block w/o then block");
575 result.addOperands(cond);
576
577 // Add regions and blocks.
578 OpBuilder::InsertionGuard guard(builder);
579 Region *thenRegion = result.addRegion();
580 if (addThenBlock)
581 builder.createBlock(thenRegion);
582 Region *elseRegion = result.addRegion();
583 if (addElseBlock)
584 builder.createBlock(elseRegion);
585}
586
587void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
588 bool withElseRegion) {
589 result.addOperands(cond);
590
591 // Build then region.
592 OpBuilder::InsertionGuard guard(builder);
593 Region *thenRegion = result.addRegion();
594 builder.createBlock(thenRegion);
595
596 // Build else region.
597 Region *elseRegion = result.addRegion();
598 if (withElseRegion) {
599 builder.createBlock(elseRegion);
600 }
601}
602
603void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
604 function_ref<void(OpBuilder &, Location)> thenBuilder,
605 function_ref<void(OpBuilder &, Location)> elseBuilder) {
606 assert(thenBuilder && "the builder callback for 'then' must be present");
607 result.addOperands(cond);
608
609 // Build then region.
610 OpBuilder::InsertionGuard guard(builder);
611 Region *thenRegion = result.addRegion();
612 builder.createBlock(thenRegion);
613 thenBuilder(builder, result.location);
614
615 // Build else region.
616 Region *elseRegion = result.addRegion();
617 if (elseBuilder) {
618 builder.createBlock(elseRegion);
619 elseBuilder(builder, result.location);
620 }
621}
622
623ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
624 // Create the regions for 'then'.
625 result.regions.reserve(2);
626 Region *thenRegion = result.addRegion();
627 Region *elseRegion = result.addRegion();
628
629 Builder &builder = parser.getBuilder();
630 OpAsmParser::UnresolvedOperand cond;
631 Type i1Type = builder.getIntegerType(1);
632 if (parser.parseOperand(cond) ||
633 parser.resolveOperand(cond, i1Type, result.operands))
634 return failure();
635 // Parse the 'then' region.
636 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
637 return failure();
638 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
639
640 // If we find an 'else' keyword then parse the 'else' region.
641 if (!parser.parseOptionalKeyword("else")) {
642 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
643 return failure();
644 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
645 }
646
647 // Parse the optional attribute list.
648 if (parser.parseOptionalAttrDict(result.attributes))
649 return failure();
650 return success();
651}
652
653void IfOp::print(OpAsmPrinter &p) {
654 bool printBlockTerminators = false;
655
656 p << " " << getCondition();
657 p << ' ';
658 p.printRegion(getThenRegion(),
659 /*printEntryBlockArgs=*/false,
660 /*printBlockTerminators=*/printBlockTerminators);
661
662 // Print the 'else' regions if it exists and has a block.
663 Region &elseRegion = getElseRegion();
664 if (!elseRegion.empty()) {
665 p << " else ";
666 p.printRegion(elseRegion,
667 /*printEntryBlockArgs=*/false,
668 /*printBlockTerminators=*/printBlockTerminators);
669 }
670
671 p.printOptionalAttrDict((*this)->getAttrs());
672}
673
674/// Given the region at `index`, or the parent operation if `index` is None,
675/// return the successor regions. These are the regions that may be selected
676/// during the flow of control. `operands` is a set of optional attributes that
677/// correspond to a constant value for each operand, or null if that operand is
678/// not a constant.
679void IfOp::getSuccessorRegions(RegionBranchPoint point,
680 SmallVectorImpl<RegionSuccessor> &regions) {
681 // The `then` and the `else` region branch back to the parent operation.
682 if (!point.isParent()) {
683 regions.push_back(RegionSuccessor());
684 return;
685 }
686
687 regions.push_back(RegionSuccessor(&getThenRegion()));
688
689 // Don't consider the else region if it is empty.
690 Region *elseRegion = &this->getElseRegion();
691 if (elseRegion->empty())
692 regions.push_back(RegionSuccessor());
693 else
694 regions.push_back(RegionSuccessor(elseRegion));
695}
696
697void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
698 SmallVectorImpl<RegionSuccessor> &regions) {
699 FoldAdaptor adaptor(operands, *this);
700 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
701 if (!boolAttr || boolAttr.getValue())
702 regions.emplace_back(&getThenRegion());
703
704 // If the else region is empty, execution continues after the parent op.
705 if (!boolAttr || !boolAttr.getValue()) {
706 if (!getElseRegion().empty())
707 regions.emplace_back(&getElseRegion());
708 else
709 regions.emplace_back();
710 }
711}
712
713void IfOp::getRegionInvocationBounds(
714 ArrayRef<Attribute> operands,
715 SmallVectorImpl<InvocationBounds> &invocationBounds) {
716 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
717 // If the condition is known, then one region is known to be executed once
718 // and the other zero times.
719 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
720 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
721 } else {
722 // Non-constant condition. Each region may be executed 0 or 1 times.
723 invocationBounds.assign(2, {0, 1});
724 }
725}
726
727//===----------------------------------------------------------------------===//
728// IncludeOp
729//===----------------------------------------------------------------------===//
730
731void IncludeOp::print(OpAsmPrinter &p) {
732 bool standardInclude = getIsStandardInclude();
733
734 p << " ";
735 if (standardInclude)
736 p << "<";
737 p << "\"" << getInclude() << "\"";
738 if (standardInclude)
739 p << ">";
740}
741
742ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
743 bool standardInclude = !parser.parseOptionalLess();
744
745 StringAttr include;
746 OptionalParseResult includeParseResult =
747 parser.parseOptionalAttribute(include, "include", result.attributes);
748 if (!includeParseResult.has_value())
749 return parser.emitError(parser.getNameLoc()) << "expected string attribute";
750
751 if (standardInclude && parser.parseOptionalGreater())
752 return parser.emitError(parser.getNameLoc())
753 << "expected trailing '>' for standard include";
754
755 if (standardInclude)
756 result.addAttribute("is_standard_include",
757 UnitAttr::get(parser.getContext()));
758
759 return success();
760}
761
762//===----------------------------------------------------------------------===//
763// LiteralOp
764//===----------------------------------------------------------------------===//
765
766/// The literal op requires a non-empty value.
767LogicalResult emitc::LiteralOp::verify() {
768 if (getValue().empty())
769 return emitOpError() << "value must not be empty";
770 return success();
771}
772//===----------------------------------------------------------------------===//
773// SubOp
774//===----------------------------------------------------------------------===//
775
776LogicalResult SubOp::verify() {
777 Type lhsType = getLhs().getType();
778 Type rhsType = getRhs().getType();
779 Type resultType = getResult().getType();
780
781 if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
782 return emitOpError("rhs can only be a pointer if lhs is a pointer");
783
784 if (isa<emitc::PointerType>(lhsType) &&
785 !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
786 return emitOpError("requires that rhs is an integer, pointer or of opaque "
787 "type if lhs is a pointer");
788
789 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
790 !isa<IntegerType, emitc::OpaqueType>(resultType))
791 return emitOpError("requires that the result is an integer or of opaque "
792 "type if lhs and rhs are pointers");
793 return success();
794}
795
796//===----------------------------------------------------------------------===//
797// VariableOp
798//===----------------------------------------------------------------------===//
799
800LogicalResult emitc::VariableOp::verify() {
801 return verifyInitializationAttribute(getOperation(), getValueAttr());
802}
803
804//===----------------------------------------------------------------------===//
805// YieldOp
806//===----------------------------------------------------------------------===//
807
808LogicalResult emitc::YieldOp::verify() {
809 Value result = getResult();
810 Operation *containingOp = getOperation()->getParentOp();
811
812 if (result && containingOp->getNumResults() != 1)
813 return emitOpError() << "yields a value not returned by parent";
814
815 if (!result && containingOp->getNumResults() != 0)
816 return emitOpError() << "does not yield a value to be returned by parent";
817
818 return success();
819}
820
821//===----------------------------------------------------------------------===//
822// SubscriptOp
823//===----------------------------------------------------------------------===//
824
825LogicalResult emitc::SubscriptOp::verify() {
826 // Checks for array operand.
827 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
828 // Check number of indices.
829 if (getIndices().size() != (size_t)arrayType.getRank()) {
830 return emitOpError() << "on array operand requires number of indices ("
831 << getIndices().size()
832 << ") to match the rank of the array type ("
833 << arrayType.getRank() << ")";
834 }
835 // Check types of index operands.
836 for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
837 Type type = getIndices()[i].getType();
838 if (!isIntegerIndexOrOpaqueType(type)) {
839 return emitOpError() << "on array operand requires index operand " << i
840 << " to be integer-like, but got " << type;
841 }
842 }
843 // Check element type.
844 Type elementType = arrayType.getElementType();
845 if (elementType != getType()) {
846 return emitOpError() << "on array operand requires element type ("
847 << elementType << ") and result type (" << getType()
848 << ") to match";
849 }
850 return success();
851 }
852
853 // Checks for pointer operand.
854 if (auto pointerType =
855 llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
856 // Check number of indices.
857 if (getIndices().size() != 1) {
858 return emitOpError()
859 << "on pointer operand requires one index operand, but got "
860 << getIndices().size();
861 }
862 // Check types of index operand.
863 Type type = getIndices()[0].getType();
864 if (!isIntegerIndexOrOpaqueType(type)) {
865 return emitOpError() << "on pointer operand requires index operand to be "
866 "integer-like, but got "
867 << type;
868 }
869 // Check pointee type.
870 Type pointeeType = pointerType.getPointee();
871 if (pointeeType != getType()) {
872 return emitOpError() << "on pointer operand requires pointee type ("
873 << pointeeType << ") and result type (" << getType()
874 << ") to match";
875 }
876 return success();
877 }
878
879 // The operand has opaque type, so we can't assume anything about the number
880 // or types of index operands.
881 return success();
882}
883
884//===----------------------------------------------------------------------===//
885// EmitC Enums
886//===----------------------------------------------------------------------===//
887
888#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
889
890//===----------------------------------------------------------------------===//
891// EmitC Attributes
892//===----------------------------------------------------------------------===//
893
894#define GET_ATTRDEF_CLASSES
895#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
896
897//===----------------------------------------------------------------------===//
898// EmitC Types
899//===----------------------------------------------------------------------===//
900
901#define GET_TYPEDEF_CLASSES
902#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
903
904//===----------------------------------------------------------------------===//
905// ArrayType
906//===----------------------------------------------------------------------===//
907
908Type emitc::ArrayType::parse(AsmParser &parser) {
909 if (parser.parseLess())
910 return Type();
911
912 SmallVector<int64_t, 4> dimensions;
913 if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
914 /*withTrailingX=*/true))
915 return Type();
916 // Parse the element type.
917 auto typeLoc = parser.getCurrentLocation();
918 Type elementType;
919 if (parser.parseType(elementType))
920 return Type();
921
922 // Check that array is formed from allowed types.
923 if (!isValidElementType(elementType))
924 return parser.emitError(typeLoc, "invalid array element type"), Type();
925 if (parser.parseGreater())
926 return Type();
927 return parser.getChecked<ArrayType>(dimensions, elementType);
928}
929
930void emitc::ArrayType::print(AsmPrinter &printer) const {
931 printer << "<";
932 for (int64_t dim : getShape()) {
933 printer << dim << 'x';
934 }
935 printer.printType(getElementType());
936 printer << ">";
937}
938
939LogicalResult emitc::ArrayType::verify(
940 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
941 ::llvm::ArrayRef<int64_t> shape, Type elementType) {
942 if (shape.empty())
943 return emitError() << "shape must not be empty";
944
945 for (int64_t dim : shape) {
946 if (dim <= 0)
947 return emitError() << "dimensions must have positive size";
948 }
949
950 if (!elementType)
951 return emitError() << "element type must not be none";
952
953 if (!isValidElementType(elementType))
954 return emitError() << "invalid array element type";
955
956 return success();
957}
958
959emitc::ArrayType
960emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
961 Type elementType) const {
962 if (!shape)
963 return emitc::ArrayType::get(getShape(), elementType);
964 return emitc::ArrayType::get(*shape, elementType);
965}
966
967//===----------------------------------------------------------------------===//
968// OpaqueType
969//===----------------------------------------------------------------------===//
970
971LogicalResult mlir::emitc::OpaqueType::verify(
972 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
973 llvm::StringRef value) {
974 if (value.empty()) {
975 return emitError() << "expected non empty string in !emitc.opaque type";
976 }
977 if (value.back() == '*') {
978 return emitError() << "pointer not allowed as outer type with "
979 "!emitc.opaque, use !emitc.ptr instead";
980 }
981 return success();
982}
983
984//===----------------------------------------------------------------------===//
985// GlobalOp
986//===----------------------------------------------------------------------===//
987static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
988 TypeAttr type,
989 Attribute initialValue) {
990 p << type;
991 if (initialValue) {
992 p << " = ";
993 p.printAttributeWithoutType(attr: initialValue);
994 }
995}
996
997static Type getInitializerTypeForGlobal(Type type) {
998 if (auto array = llvm::dyn_cast<ArrayType>(type))
999 return RankedTensorType::get(array.getShape(), array.getElementType());
1000 return type;
1001}
1002
1003static ParseResult
1004parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1005 Attribute &initialValue) {
1006 Type type;
1007 if (parser.parseType(result&: type))
1008 return failure();
1009
1010 typeAttr = TypeAttr::get(type);
1011
1012 if (parser.parseOptionalEqual())
1013 return success();
1014
1015 if (parser.parseAttribute(result&: initialValue, type: getInitializerTypeForGlobal(type)))
1016 return failure();
1017
1018 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1019 initialValue))
1020 return parser.emitError(loc: parser.getNameLoc())
1021 << "initial value should be a integer, float, elements or opaque "
1022 "attribute";
1023 return success();
1024}
1025
1026LogicalResult GlobalOp::verify() {
1027 if (!isSupportedEmitCType(getType())) {
1028 return emitOpError("expected valid emitc type");
1029 }
1030 if (getInitialValue().has_value()) {
1031 Attribute initValue = getInitialValue().value();
1032 // Check that the type of the initial value is compatible with the type of
1033 // the global variable.
1034 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1035 auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1036 if (!arrayType)
1037 return emitOpError("expected array type, but got ") << getType();
1038
1039 Type initType = elementsAttr.getType();
1040 Type tensorType = getInitializerTypeForGlobal(getType());
1041 if (initType != tensorType) {
1042 return emitOpError("initial value expected to be of type ")
1043 << getType() << ", but was of type " << initType;
1044 }
1045 } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1046 if (intAttr.getType() != getType()) {
1047 return emitOpError("initial value expected to be of type ")
1048 << getType() << ", but was of type " << intAttr.getType();
1049 }
1050 } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1051 if (floatAttr.getType() != getType()) {
1052 return emitOpError("initial value expected to be of type ")
1053 << getType() << ", but was of type " << floatAttr.getType();
1054 }
1055 } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1056 return emitOpError("initial value should be a integer, float, elements "
1057 "or opaque attribute, but got ")
1058 << initValue;
1059 }
1060 }
1061 if (getStaticSpecifier() && getExternSpecifier()) {
1062 return emitOpError("cannot have both static and extern specifiers");
1063 }
1064 return success();
1065}
1066
1067//===----------------------------------------------------------------------===//
1068// GetGlobalOp
1069//===----------------------------------------------------------------------===//
1070
1071LogicalResult
1072GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1073 // Verify that the type matches the type of the global variable.
1074 auto global =
1075 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1076 if (!global)
1077 return emitOpError("'")
1078 << getName() << "' does not reference a valid emitc.global";
1079
1080 Type resultType = getResult().getType();
1081 if (global.getType() != resultType)
1082 return emitOpError("result type ")
1083 << resultType << " does not match type " << global.getType()
1084 << " of the global @" << getName();
1085 return success();
1086}
1087
1088//===----------------------------------------------------------------------===//
1089// TableGen'd op method definitions
1090//===----------------------------------------------------------------------===//
1091
1092#define GET_OP_CLASSES
1093#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
1094

source code of mlir/lib/Dialect/EmitC/IR/EmitC.cpp