1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "clang/CIR/Dialect/IR/CIRDialect.h"
14
15#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
16#include "clang/CIR/Dialect/IR/CIRTypes.h"
17
18#include "mlir/Interfaces/ControlFlowInterfaces.h"
19#include "mlir/Interfaces/FunctionImplementation.h"
20
21#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
22#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
23#include "clang/CIR/MissingFeatures.h"
24
25#include <numeric>
26
27using namespace mlir;
28using namespace cir;
29
30//===----------------------------------------------------------------------===//
31// CIR Dialect
32//===----------------------------------------------------------------------===//
33namespace {
34struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
35 using OpAsmDialectInterface::OpAsmDialectInterface;
36
37 AliasResult getAlias(Type type, raw_ostream &os) const final {
38 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
39 StringAttr nameAttr = recordType.getName();
40 if (!nameAttr)
41 os << "rec_anon_" << recordType.getKindAsStr();
42 else
43 os << "rec_" << nameAttr.getValue();
44 return AliasResult::OverridableAlias;
45 }
46 if (auto intType = dyn_cast<cir::IntType>(type)) {
47 // We only provide alias for standard integer types (i.e. integer types
48 // whose width is a power of 2 and at least 8).
49 unsigned width = intType.getWidth();
50 if (width < 8 || !llvm::isPowerOf2_32(Value: width))
51 return AliasResult::NoAlias;
52 os << intType.getAlias();
53 return AliasResult::OverridableAlias;
54 }
55 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
56 os << voidType.getAlias();
57 return AliasResult::OverridableAlias;
58 }
59
60 return AliasResult::NoAlias;
61 }
62
63 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
64 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
65 os << (boolAttr.getValue() ? "true" : "false");
66 return AliasResult::FinalAlias;
67 }
68 return AliasResult::NoAlias;
69 }
70};
71} // namespace
72
/// Dialect initialization: registers the CIR types and attributes, every
/// operation listed in the TableGen-generated op list, and the asm-printer
/// alias interface.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // GET_OP_LIST makes the generated .inc expand to a comma-separated list of
  // op classes, which becomes the template argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  // Provides the type/attribute aliases used when printing IR.
  addInterfaces<CIROpAsmDialectInterface>();
}
82
83Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
84 mlir::Attribute value,
85 mlir::Type type,
86 mlir::Location loc) {
87 return builder.create<cir::ConstantOp>(loc, type,
88 mlir::cast<mlir::TypedAttr>(value));
89}
90
91//===----------------------------------------------------------------------===//
92// Helpers
93//===----------------------------------------------------------------------===//
94
95// Check if a region's termination omission is valid and, if so, creates and
96// inserts the omitted terminator into the region.
97static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
98 SMLoc errLoc) {
99 Location eLoc = parser.getEncodedSourceLoc(loc: parser.getCurrentLocation());
100 OpBuilder builder(parser.getBuilder().getContext());
101
102 // Insert empty block in case the region is empty to ensure the terminator
103 // will be inserted
104 if (region.empty())
105 builder.createBlock(parent: &region);
106
107 Block &block = region.back();
108 // Region is properly terminated: nothing to do.
109 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
110 return success();
111
112 // Check for invalid terminator omissions.
113 if (!region.hasOneBlock())
114 return parser.emitError(loc: errLoc,
115 message: "multi-block region must not omit terminator");
116
117 // Terminator was omitted correctly: recreate it.
118 builder.setInsertionPointToEnd(&block);
119 builder.create<cir::YieldOp>(eLoc);
120 return success();
121}
122
123// True if the region's terminator should be omitted.
124static bool omitRegionTerm(mlir::Region &r) {
125 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
126 const auto yieldsNothing = [&r]() {
127 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
128 return y && y.getArgs().empty();
129 };
130 return singleNonEmptyBlock && yieldsNothing();
131}
132
133//===----------------------------------------------------------------------===//
134// CIR Custom Parsers/Printers
135//===----------------------------------------------------------------------===//
136
137static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
138 mlir::Region &region) {
139 auto regionLoc = parser.getCurrentLocation();
140 if (parser.parseRegion(region))
141 return failure();
142 if (ensureRegionTerm(parser, region, errLoc: regionLoc).failed())
143 return failure();
144 return success();
145}
146
147static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
148 cir::ScopeOp &op,
149 mlir::Region &region) {
150 printer.printRegion(blocks&: region,
151 /*printEntryBlockArgs=*/false,
152 /*printBlockTerminators=*/!omitRegionTerm(r&: region));
153}
154
155//===----------------------------------------------------------------------===//
156// AllocaOp
157//===----------------------------------------------------------------------===//
158
159void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
160 mlir::OperationState &odsState, mlir::Type addr,
161 mlir::Type allocaType, llvm::StringRef name,
162 mlir::IntegerAttr alignment) {
163 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
164 mlir::TypeAttr::get(allocaType));
165 odsState.addAttribute(getNameAttrName(odsState.name),
166 odsBuilder.getStringAttr(name));
167 if (alignment) {
168 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
169 }
170 odsState.addTypes(addr);
171}
172
173//===----------------------------------------------------------------------===//
174// BreakOp
175//===----------------------------------------------------------------------===//
176
177LogicalResult cir::BreakOp::verify() {
178 assert(!cir::MissingFeatures::switchOp());
179 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
180 !getOperation()->getParentOfType<SwitchOp>())
181 return emitOpError("must be within a loop");
182 return success();
183}
184
185//===----------------------------------------------------------------------===//
186// ConditionOp
187//===----------------------------------------------------------------------===//
188
189//===----------------------------------
// RegionBranchTerminatorOpInterface Methods
191//===----------------------------------
192
193void cir::ConditionOp::getSuccessorRegions(
194 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
195 // TODO(cir): The condition value may be folded to a constant, narrowing
196 // down its list of possible successors.
197
198 // Parent is a loop: condition may branch to the body or to the parent op.
199 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
200 regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
201 regions.emplace_back(loopOp->getResults());
202 }
203
204 assert(!cir::MissingFeatures::awaitOp());
205}
206
207MutableOperandRange
208cir::ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
209 // No values are yielded to the successor region.
210 return MutableOperandRange(getOperation(), 0, 0);
211}
212
213LogicalResult cir::ConditionOp::verify() {
214 assert(!cir::MissingFeatures::awaitOp());
215 if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
216 return emitOpError("condition must be within a conditional region");
217 return success();
218}
219
220//===----------------------------------------------------------------------===//
221// ConstantOp
222//===----------------------------------------------------------------------===//
223
224static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
225 mlir::Attribute attrType) {
226 if (isa<cir::ConstPtrAttr>(attrType)) {
227 if (!mlir::isa<cir::PointerType>(opType))
228 return op->emitOpError(
229 message: "pointer constant initializing a non-pointer type");
230 return success();
231 }
232
233 if (isa<cir::ZeroAttr>(attrType)) {
234 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
235 opType))
236 return success();
237 return op->emitOpError(
238 message: "zero expects struct, array, vector, or complex type");
239 }
240
241 if (mlir::isa<cir::BoolAttr>(attrType)) {
242 if (!mlir::isa<cir::BoolType>(opType))
243 return op->emitOpError(message: "result type (")
244 << opType << ") must be '!cir.bool' for '" << attrType << "'";
245 return success();
246 }
247
248 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
249 auto at = cast<TypedAttr>(attrType);
250 if (at.getType() != opType) {
251 return op->emitOpError(message: "result type (")
252 << opType << ") does not match value type (" << at.getType()
253 << ")";
254 }
255 return success();
256 }
257
258 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
259 cir::ConstComplexAttr>(attrType))
260 return success();
261
262 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
263 return op->emitOpError("global with type ")
264 << cast<TypedAttr>(attrType).getType() << " not yet supported";
265}
266
267LogicalResult cir::ConstantOp::verify() {
268 // ODS already generates checks to make sure the result type is valid. We just
269 // need to additionally check that the value's attribute type is consistent
270 // with the result type.
271 return checkConstantTypes(getOperation(), getType(), getValue());
272}
273
274OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
275 return getValue();
276}
277
278//===----------------------------------------------------------------------===//
279// ContinueOp
280//===----------------------------------------------------------------------===//
281
282LogicalResult cir::ContinueOp::verify() {
283 if (!getOperation()->getParentOfType<LoopOpInterface>())
284 return emitOpError("must be within a loop");
285 return success();
286}
287
288//===----------------------------------------------------------------------===//
289// CastOp
290//===----------------------------------------------------------------------===//
291
292LogicalResult cir::CastOp::verify() {
293 mlir::Type resType = getType();
294 mlir::Type srcType = getSrc().getType();
295
296 if (mlir::isa<cir::VectorType>(srcType) &&
297 mlir::isa<cir::VectorType>(resType)) {
298 // Use the element type of the vector to verify the cast kind. (Except for
299 // bitcast, see below.)
300 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
301 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
302 }
303
304 switch (getKind()) {
305 case cir::CastKind::int_to_bool: {
306 if (!mlir::isa<cir::BoolType>(resType))
307 return emitOpError() << "requires !cir.bool type for result";
308 if (!mlir::isa<cir::IntType>(srcType))
309 return emitOpError() << "requires !cir.int type for source";
310 return success();
311 }
312 case cir::CastKind::ptr_to_bool: {
313 if (!mlir::isa<cir::BoolType>(resType))
314 return emitOpError() << "requires !cir.bool type for result";
315 if (!mlir::isa<cir::PointerType>(srcType))
316 return emitOpError() << "requires !cir.ptr type for source";
317 return success();
318 }
319 case cir::CastKind::integral: {
320 if (!mlir::isa<cir::IntType>(resType))
321 return emitOpError() << "requires !cir.int type for result";
322 if (!mlir::isa<cir::IntType>(srcType))
323 return emitOpError() << "requires !cir.int type for source";
324 return success();
325 }
326 case cir::CastKind::array_to_ptrdecay: {
327 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
328 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
329 if (!arrayPtrTy || !flatPtrTy)
330 return emitOpError() << "requires !cir.ptr type for source and result";
331
332 // TODO(CIR): Make sure the AddrSpace of both types are equals
333 return success();
334 }
335 case cir::CastKind::bitcast: {
336 // Handle the pointer types first.
337 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
338 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
339
340 if (srcPtrTy && resPtrTy) {
341 return success();
342 }
343
344 return success();
345 }
346 case cir::CastKind::floating: {
347 if (!mlir::isa<cir::CIRFPTypeInterface>(srcType) ||
348 !mlir::isa<cir::CIRFPTypeInterface>(resType))
349 return emitOpError() << "requires !cir.float type for source and result";
350 return success();
351 }
352 case cir::CastKind::float_to_int: {
353 if (!mlir::isa<cir::CIRFPTypeInterface>(srcType))
354 return emitOpError() << "requires !cir.float type for source";
355 if (!mlir::dyn_cast<cir::IntType>(resType))
356 return emitOpError() << "requires !cir.int type for result";
357 return success();
358 }
359 case cir::CastKind::int_to_ptr: {
360 if (!mlir::dyn_cast<cir::IntType>(srcType))
361 return emitOpError() << "requires !cir.int type for source";
362 if (!mlir::dyn_cast<cir::PointerType>(resType))
363 return emitOpError() << "requires !cir.ptr type for result";
364 return success();
365 }
366 case cir::CastKind::ptr_to_int: {
367 if (!mlir::dyn_cast<cir::PointerType>(srcType))
368 return emitOpError() << "requires !cir.ptr type for source";
369 if (!mlir::dyn_cast<cir::IntType>(resType))
370 return emitOpError() << "requires !cir.int type for result";
371 return success();
372 }
373 case cir::CastKind::float_to_bool: {
374 if (!mlir::isa<cir::CIRFPTypeInterface>(srcType))
375 return emitOpError() << "requires !cir.float type for source";
376 if (!mlir::isa<cir::BoolType>(resType))
377 return emitOpError() << "requires !cir.bool type for result";
378 return success();
379 }
380 case cir::CastKind::bool_to_int: {
381 if (!mlir::isa<cir::BoolType>(srcType))
382 return emitOpError() << "requires !cir.bool type for source";
383 if (!mlir::isa<cir::IntType>(resType))
384 return emitOpError() << "requires !cir.int type for result";
385 return success();
386 }
387 case cir::CastKind::int_to_float: {
388 if (!mlir::isa<cir::IntType>(srcType))
389 return emitOpError() << "requires !cir.int type for source";
390 if (!mlir::isa<cir::CIRFPTypeInterface>(resType))
391 return emitOpError() << "requires !cir.float type for result";
392 return success();
393 }
394 case cir::CastKind::bool_to_float: {
395 if (!mlir::isa<cir::BoolType>(srcType))
396 return emitOpError() << "requires !cir.bool type for source";
397 if (!mlir::isa<cir::CIRFPTypeInterface>(resType))
398 return emitOpError() << "requires !cir.float type for result";
399 return success();
400 }
401 case cir::CastKind::address_space: {
402 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
403 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
404 if (!srcPtrTy || !resPtrTy)
405 return emitOpError() << "requires !cir.ptr type for source and result";
406 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
407 return emitOpError() << "requires two types differ in addrspace only";
408 return success();
409 }
410 default:
411 llvm_unreachable("Unknown CastOp kind?");
412 }
413}
414
415static bool isIntOrBoolCast(cir::CastOp op) {
416 auto kind = op.getKind();
417 return kind == cir::CastKind::bool_to_int ||
418 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
419}
420
421static Value tryFoldCastChain(cir::CastOp op) {
422 cir::CastOp head = op, tail = op;
423
424 while (op) {
425 if (!isIntOrBoolCast(op))
426 break;
427 head = op;
428 op = dyn_cast_or_null<cir::CastOp>(head.getSrc().getDefiningOp());
429 }
430
431 if (head == tail)
432 return {};
433
434 // if bool_to_int -> ... -> int_to_bool: take the bool
435 // as we had it was before all casts
436 if (head.getKind() == cir::CastKind::bool_to_int &&
437 tail.getKind() == cir::CastKind::int_to_bool)
438 return head.getSrc();
439
440 // if int_to_bool -> ... -> int_to_bool: take the result
441 // of the first one, as no other casts (and ext casts as well)
442 // don't change the first result
443 if (head.getKind() == cir::CastKind::int_to_bool &&
444 tail.getKind() == cir::CastKind::int_to_bool)
445 return head.getResult();
446
447 return {};
448}
449
450OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
451 if (getSrc().getType() == getType()) {
452 switch (getKind()) {
453 case cir::CastKind::integral: {
454 // TODO: for sign differences, it's possible in certain conditions to
455 // create a new attribute that's capable of representing the source.
456 llvm::SmallVector<mlir::OpFoldResult, 1> foldResults;
457 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
458 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
459 return mlir::cast<mlir::Attribute>(foldResults[0]);
460 return {};
461 }
462 case cir::CastKind::bitcast:
463 case cir::CastKind::address_space:
464 case cir::CastKind::float_complex:
465 case cir::CastKind::int_complex: {
466 return getSrc();
467 }
468 default:
469 return {};
470 }
471 }
472 return tryFoldCastChain(*this);
473}
474
475//===----------------------------------------------------------------------===//
476// CallOp
477//===----------------------------------------------------------------------===//
478
479mlir::OperandRange cir::CallOp::getArgOperands() {
480 if (isIndirect())
481 return getArgs().drop_front(1);
482 return getArgs();
483}
484
485mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
486 mlir::MutableOperandRange args = getArgsMutable();
487 if (isIndirect())
488 return args.slice(1, args.size() - 1);
489 return args;
490}
491
492mlir::Value cir::CallOp::getIndirectCall() {
493 assert(isIndirect());
494 return getOperand(0);
495}
496
497/// Return the operand at index 'i'.
498Value cir::CallOp::getArgOperand(unsigned i) {
499 if (isIndirect())
500 ++i;
501 return getOperand(i);
502}
503
504/// Return the number of operands.
505unsigned cir::CallOp::getNumArgOperands() {
506 if (isIndirect())
507 return this->getOperation()->getNumOperands() - 1;
508 return this->getOperation()->getNumOperands();
509}
510
511static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
512 mlir::OperationState &result) {
513 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
514 llvm::SMLoc opsLoc;
515 mlir::FlatSymbolRefAttr calleeAttr;
516 llvm::ArrayRef<mlir::Type> allResultTypes;
517
518 // If we cannot parse a string callee, it means this is an indirect call.
519 if (!parser.parseOptionalAttribute(result&: calleeAttr, attrName: "callee", attrs&: result.attributes)
520 .has_value()) {
521 OpAsmParser::UnresolvedOperand indirectVal;
522 // Do not resolve right now, since we need to figure out the type
523 if (parser.parseOperand(result&: indirectVal).failed())
524 return failure();
525 ops.push_back(Elt: indirectVal);
526 }
527
528 if (parser.parseLParen())
529 return mlir::failure();
530
531 opsLoc = parser.getCurrentLocation();
532 if (parser.parseOperandList(result&: ops))
533 return mlir::failure();
534 if (parser.parseRParen())
535 return mlir::failure();
536
537 if (parser.parseOptionalAttrDict(result&: result.attributes))
538 return ::mlir::failure();
539
540 if (parser.parseColon())
541 return ::mlir::failure();
542
543 mlir::FunctionType opsFnTy;
544 if (parser.parseType(opsFnTy))
545 return mlir::failure();
546
547 allResultTypes = opsFnTy.getResults();
548 result.addTypes(newTypes: allResultTypes);
549
550 if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
551 return mlir::failure();
552
553 return mlir::success();
554}
555
556static void printCallCommon(mlir::Operation *op,
557 mlir::FlatSymbolRefAttr calleeSym,
558 mlir::Value indirectCallee,
559 mlir::OpAsmPrinter &printer) {
560 printer << ' ';
561
562 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
563 auto ops = callLikeOp.getArgOperands();
564
565 if (calleeSym) {
566 // Direct calls
567 printer.printAttributeWithoutType(calleeSym);
568 } else {
569 // Indirect calls
570 assert(indirectCallee);
571 printer << indirectCallee;
572 }
573 printer << "(" << ops << ")";
574
575 printer.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs: {"callee"});
576
577 printer << " : ";
578 printer.printFunctionalType(inputs: op->getOperands().getTypes(),
579 results: op->getResultTypes());
580}
581
582mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
583 mlir::OperationState &result) {
584 return parseCallCommon(parser, result);
585}
586
587void cir::CallOp::print(mlir::OpAsmPrinter &p) {
588 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
589 printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
590}
591
592static LogicalResult
593verifyCallCommInSymbolUses(mlir::Operation *op,
594 SymbolTableCollection &symbolTable) {
595 auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>(name: "callee");
596 if (!fnAttr) {
597 // This is an indirect call, thus we don't have to check the symbol uses.
598 return mlir::success();
599 }
600
601 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
602 if (!fn)
603 return op->emitOpError() << "'" << fnAttr.getValue()
604 << "' does not reference a valid function";
605
606 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
607 assert(callIf && "expected CIR call interface to be always available");
608
609 // Verify that the operand and result types match the callee. Note that
610 // argument-checking is disabled for functions without a prototype.
611 auto fnType = fn.getFunctionType();
612 if (!fn.getNoProto()) {
613 unsigned numCallOperands = callIf.getNumArgOperands();
614 unsigned numFnOpOperands = fnType.getNumInputs();
615
616 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
617 return op->emitOpError(message: "incorrect number of operands for callee");
618 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
619 return op->emitOpError(message: "too few operands for callee");
620
621 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
622 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
623 return op->emitOpError(message: "operand type mismatch: expected operand type ")
624 << fnType.getInput(i) << ", but provided "
625 << op->getOperand(idx: i).getType() << " for operand number " << i;
626 }
627
628 assert(!cir::MissingFeatures::opCallCallConv());
629
630 // Void function must not return any results.
631 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
632 return op->emitOpError(message: "callee returns void but call has results");
633
634 // Non-void function calls must return exactly one result.
635 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
636 return op->emitOpError(message: "incorrect number of results for callee");
637
638 // Parent function and return value types must match.
639 if (!fnType.hasVoidReturn() &&
640 op->getResultTypes().front() != fnType.getReturnType()) {
641 return op->emitOpError(message: "result type mismatch: expected ")
642 << fnType.getReturnType() << ", but provided "
643 << op->getResult(idx: 0).getType();
644 }
645
646 return mlir::success();
647}
648
649LogicalResult
650cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
651 return verifyCallCommInSymbolUses(*this, symbolTable);
652}
653
654//===----------------------------------------------------------------------===//
655// ReturnOp
656//===----------------------------------------------------------------------===//
657
658static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
659 cir::FuncOp function) {
660 // ReturnOps currently only have a single optional operand.
661 if (op.getNumOperands() > 1)
662 return op.emitOpError() << "expects at most 1 return operand";
663
664 // Ensure returned type matches the function signature.
665 auto expectedTy = function.getFunctionType().getReturnType();
666 auto actualTy =
667 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
668 : op.getOperand(0).getType());
669 if (actualTy != expectedTy)
670 return op.emitOpError() << "returns " << actualTy
671 << " but enclosing function returns " << expectedTy;
672
673 return mlir::success();
674}
675
676mlir::LogicalResult cir::ReturnOp::verify() {
677 // Returns can be present in multiple different scopes, get the
678 // wrapping function and start from there.
679 auto *fnOp = getOperation()->getParentOp();
680 while (!isa<cir::FuncOp>(fnOp))
681 fnOp = fnOp->getParentOp();
682
683 // Make sure return types match function return type.
684 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
685 return failure();
686
687 return success();
688}
689
690//===----------------------------------------------------------------------===//
691// IfOp
692//===----------------------------------------------------------------------===//
693
694ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
695 // create the regions for 'then'.
696 result.regions.reserve(2);
697 Region *thenRegion = result.addRegion();
698 Region *elseRegion = result.addRegion();
699
700 mlir::Builder &builder = parser.getBuilder();
701 OpAsmParser::UnresolvedOperand cond;
702 Type boolType = cir::BoolType::get(builder.getContext());
703
704 if (parser.parseOperand(cond) ||
705 parser.resolveOperand(cond, boolType, result.operands))
706 return failure();
707
708 // Parse 'then' region.
709 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
710 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
711 return failure();
712
713 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
714 return failure();
715
716 // If we find an 'else' keyword, parse the 'else' region.
717 if (!parser.parseOptionalKeyword("else")) {
718 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
719 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
720 return failure();
721 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
722 return failure();
723 }
724
725 // Parse the optional attribute list.
726 if (parser.parseOptionalAttrDict(result.attributes))
727 return failure();
728 return success();
729}
730
731void cir::IfOp::print(OpAsmPrinter &p) {
732 p << " " << getCondition() << " ";
733 mlir::Region &thenRegion = this->getThenRegion();
734 p.printRegion(thenRegion,
735 /*printEntryBlockArgs=*/false,
736 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
737
738 // Print the 'else' regions if it exists and has a block.
739 mlir::Region &elseRegion = this->getElseRegion();
740 if (!elseRegion.empty()) {
741 p << " else ";
742 p.printRegion(elseRegion,
743 /*printEntryBlockArgs=*/false,
744 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
745 }
746
747 p.printOptionalAttrDict(getOperation()->getAttrs());
748}
749
750/// Default callback for IfOp builders.
751void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
752 // add cir.yield to end of the block
753 builder.create<cir::YieldOp>(loc);
754}
755
756/// Given the region at `index`, or the parent operation if `index` is None,
757/// return the successor regions. These are the regions that may be selected
758/// during the flow of control. `operands` is a set of optional attributes that
759/// correspond to a constant value for each operand, or null if that operand is
760/// not a constant.
761void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
762 SmallVectorImpl<RegionSuccessor> &regions) {
763 // The `then` and the `else` region branch back to the parent operation.
764 if (!point.isParent()) {
765 regions.push_back(RegionSuccessor());
766 return;
767 }
768
769 // Don't consider the else region if it is empty.
770 Region *elseRegion = &this->getElseRegion();
771 if (elseRegion->empty())
772 elseRegion = nullptr;
773
774 // If the condition isn't constant, both regions may be executed.
775 regions.push_back(RegionSuccessor(&getThenRegion()));
776 // If the else region does not exist, it is not a viable successor.
777 if (elseRegion)
778 regions.push_back(RegionSuccessor(elseRegion));
779
780 return;
781}
782
783void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
784 bool withElseRegion, BuilderCallbackRef thenBuilder,
785 BuilderCallbackRef elseBuilder) {
786 assert(thenBuilder && "the builder callback for 'then' must be present");
787 result.addOperands(cond);
788
789 OpBuilder::InsertionGuard guard(builder);
790 Region *thenRegion = result.addRegion();
791 builder.createBlock(thenRegion);
792 thenBuilder(builder, result.location);
793
794 Region *elseRegion = result.addRegion();
795 if (!withElseRegion)
796 return;
797
798 builder.createBlock(elseRegion);
799 elseBuilder(builder, result.location);
800}
801
802//===----------------------------------------------------------------------===//
803// ScopeOp
804//===----------------------------------------------------------------------===//
805
806/// Given the region at `index`, or the parent operation if `index` is None,
807/// return the successor regions. These are the regions that may be selected
808/// during the flow of control. `operands` is a set of optional attributes
809/// that correspond to a constant value for each operand, or null if that
810/// operand is not a constant.
811void cir::ScopeOp::getSuccessorRegions(
812 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
813 // The only region always branch back to the parent operation.
814 if (!point.isParent()) {
815 regions.push_back(RegionSuccessor(getODSResults(0)));
816 return;
817 }
818
819 // If the condition isn't constant, both regions may be executed.
820 regions.push_back(RegionSuccessor(&getScopeRegion()));
821}
822
823void cir::ScopeOp::build(
824 OpBuilder &builder, OperationState &result,
825 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
826 assert(scopeBuilder && "the builder callback for 'then' must be present");
827
828 OpBuilder::InsertionGuard guard(builder);
829 Region *scopeRegion = result.addRegion();
830 builder.createBlock(scopeRegion);
831 assert(!cir::MissingFeatures::opScopeCleanupRegion());
832
833 mlir::Type yieldTy;
834 scopeBuilder(builder, yieldTy, result.location);
835
836 if (yieldTy)
837 result.addTypes(TypeRange{yieldTy});
838}
839
840void cir::ScopeOp::build(
841 OpBuilder &builder, OperationState &result,
842 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
843 assert(scopeBuilder && "the builder callback for 'then' must be present");
844 OpBuilder::InsertionGuard guard(builder);
845 Region *scopeRegion = result.addRegion();
846 builder.createBlock(scopeRegion);
847 assert(!cir::MissingFeatures::opScopeCleanupRegion());
848 scopeBuilder(builder, result.location);
849}
850
851LogicalResult cir::ScopeOp::verify() {
852 if (getRegion().empty()) {
853 return emitOpError() << "cir.scope must not be empty since it should "
854 "include at least an implicit cir.yield ";
855 }
856
857 mlir::Block &lastBlock = getRegion().back();
858 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
859 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
860 return emitOpError() << "last block of cir.scope must be terminated";
861 return success();
862}
863
864//===----------------------------------------------------------------------===//
865// BrOp
866//===----------------------------------------------------------------------===//
867
868mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
869 assert(index == 0 && "invalid successor index");
870 return mlir::SuccessorOperands(getDestOperandsMutable());
871}
872
873Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
874 return getDest();
875}
876
877//===----------------------------------------------------------------------===//
878// BrCondOp
879//===----------------------------------------------------------------------===//
880
881mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
882 assert(index < getNumSuccessors() && "invalid successor index");
883 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
884 : getDestOperandsFalseMutable());
885}
886
887Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
888 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
889 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
890 return nullptr;
891}
892
893//===----------------------------------------------------------------------===//
894// CaseOp
895//===----------------------------------------------------------------------===//
896
897void cir::CaseOp::getSuccessorRegions(
898 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
899 if (!point.isParent()) {
900 regions.push_back(RegionSuccessor());
901 return;
902 }
903 regions.push_back(RegionSuccessor(&getCaseRegion()));
904}
905
906void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
907 ArrayAttr value, CaseOpKind kind,
908 OpBuilder::InsertPoint &insertPoint) {
909 OpBuilder::InsertionGuard guardSwitch(builder);
910 result.addAttribute("value", value);
911 result.getOrAddProperties<Properties>().kind =
912 cir::CaseOpKindAttr::get(builder.getContext(), kind);
913 Region *caseRegion = result.addRegion();
914 builder.createBlock(caseRegion);
915
916 insertPoint = builder.saveInsertionPoint();
917}
918
919//===----------------------------------------------------------------------===//
920// SwitchOp
921//===----------------------------------------------------------------------===//
922
923static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
924 mlir::OpAsmParser::UnresolvedOperand &cond,
925 mlir::Type &condType) {
926 cir::IntType intCondType;
927
928 if (parser.parseLParen())
929 return mlir::failure();
930
931 if (parser.parseOperand(result&: cond))
932 return mlir::failure();
933 if (parser.parseColon())
934 return mlir::failure();
935 if (parser.parseCustomTypeWithFallback(intCondType))
936 return mlir::failure();
937 condType = intCondType;
938
939 if (parser.parseRParen())
940 return mlir::failure();
941 if (parser.parseRegion(region&: regions, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
942 return failure();
943
944 return mlir::success();
945}
946
947static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
948 mlir::Region &bodyRegion, mlir::Value condition,
949 mlir::Type condType) {
950 p << "(";
951 p << condition;
952 p << " : ";
953 p.printStrippedAttrOrType(attrOrType: condType);
954 p << ")";
955
956 p << ' ';
957 p.printRegion(blocks&: bodyRegion, /*printEntryBlockArgs=*/false,
958 /*printBlockTerminators=*/true);
959}
960
961void cir::SwitchOp::getSuccessorRegions(
962 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
963 if (!point.isParent()) {
964 region.push_back(RegionSuccessor());
965 return;
966 }
967
968 region.push_back(RegionSuccessor(&getBody()));
969}
970
971void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
972 Value cond, BuilderOpStateCallbackRef switchBuilder) {
973 assert(switchBuilder && "the builder callback for regions must be present");
974 OpBuilder::InsertionGuard guardSwitch(builder);
975 Region *switchRegion = result.addRegion();
976 builder.createBlock(switchRegion);
977 result.addOperands({cond});
978 switchBuilder(builder, result.location, result);
979}
980
981void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
982 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
983 // Don't walk in nested switch op.
984 if (isa<cir::SwitchOp>(op) && op != *this)
985 return WalkResult::skip();
986
987 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
988 cases.push_back(caseOp);
989
990 return WalkResult::advance();
991 });
992}
993
994bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
995 collectCases(cases);
996
997 if (getBody().empty())
998 return false;
999
1000 if (!isa<YieldOp>(getBody().front().back()))
1001 return false;
1002
1003 if (!llvm::all_of(getBody().front(),
1004 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1005 return false;
1006
1007 return llvm::all_of(cases, [this](CaseOp op) {
1008 return op->getParentOfType<SwitchOp>() == *this;
1009 });
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// SwitchFlatOp
1014//===----------------------------------------------------------------------===//
1015
1016void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1017 Value value, Block *defaultDestination,
1018 ValueRange defaultOperands,
1019 ArrayRef<APInt> caseValues,
1020 BlockRange caseDestinations,
1021 ArrayRef<ValueRange> caseOperands) {
1022
1023 std::vector<mlir::Attribute> caseValuesAttrs;
1024 for (const APInt &val : caseValues)
1025 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1026 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1027
1028 build(builder, result, value, defaultOperands, caseOperands, attrs,
1029 defaultDestination, caseDestinations);
1030}
1031
1032/// <cases> ::= `[` (case (`,` case )* )? `]`
1033/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1034static ParseResult parseSwitchFlatOpCases(
1035 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1036 SmallVectorImpl<Block *> &caseDestinations,
1037 SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
1038 &caseOperands,
1039 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1040 if (failed(Result: parser.parseLSquare()))
1041 return failure();
1042 if (succeeded(Result: parser.parseOptionalRSquare()))
1043 return success();
1044 llvm::SmallVector<mlir::Attribute> values;
1045
1046 auto parseCase = [&]() {
1047 int64_t value = 0;
1048 if (failed(Result: parser.parseInteger(result&: value)))
1049 return failure();
1050
1051 values.push_back(cir::IntAttr::get(flagType, value));
1052
1053 Block *destination;
1054 llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
1055 llvm::SmallVector<Type> operandTypes;
1056 if (parser.parseColon() || parser.parseSuccessor(dest&: destination))
1057 return failure();
1058 if (!parser.parseOptionalLParen()) {
1059 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::None,
1060 /*allowResultNumber=*/false) ||
1061 parser.parseColonTypeList(result&: operandTypes) || parser.parseRParen())
1062 return failure();
1063 }
1064 caseDestinations.push_back(Elt: destination);
1065 caseOperands.emplace_back(Args&: operands);
1066 caseOperandTypes.emplace_back(Args&: operandTypes);
1067 return success();
1068 };
1069 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseCase)))
1070 return failure();
1071
1072 caseValues = ArrayAttr::get(flagType.getContext(), values);
1073
1074 return parser.parseRSquare();
1075}
1076
1077static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1078 Type flagType, mlir::ArrayAttr caseValues,
1079 SuccessorRange caseDestinations,
1080 OperandRangeRange caseOperands,
1081 const TypeRangeRange &caseOperandTypes) {
1082 p << '[';
1083 p.printNewline();
1084 if (!caseValues) {
1085 p << ']';
1086 return;
1087 }
1088
1089 size_t index = 0;
1090 llvm::interleave(
1091 llvm::zip(caseValues, caseDestinations),
1092 [&](auto i) {
1093 p << " ";
1094 mlir::Attribute a = std::get<0>(i);
1095 p << mlir::cast<cir::IntAttr>(a).getValue();
1096 p << ": ";
1097 p.printSuccessorAndUseList(successor: std::get<1>(i), succOperands: caseOperands[index++]);
1098 },
1099 [&] {
1100 p << ',';
1101 p.printNewline();
1102 });
1103 p.printNewline();
1104 p << ']';
1105}
1106
1107//===----------------------------------------------------------------------===//
1108// GlobalOp
1109//===----------------------------------------------------------------------===//
1110
1111static ParseResult parseConstantValue(OpAsmParser &parser,
1112 mlir::Attribute &valueAttr) {
1113 NamedAttrList attr;
1114 return parser.parseAttribute(result&: valueAttr, attrName: "value", attrs&: attr);
1115}
1116
1117static void printConstant(OpAsmPrinter &p, Attribute value) {
1118 p.printAttribute(attr: value);
1119}
1120
1121mlir::LogicalResult cir::GlobalOp::verify() {
1122 // Verify that the initial value, if present, is either a unit attribute or
1123 // an attribute CIR supports.
1124 if (getInitialValue().has_value()) {
1125 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1126 .failed())
1127 return failure();
1128 }
1129
1130 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1131 // yet.
1132
1133 return success();
1134}
1135
1136void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1137 llvm::StringRef sym_name, mlir::Type sym_type,
1138 cir::GlobalLinkageKind linkage) {
1139 odsState.addAttribute(getSymNameAttrName(odsState.name),
1140 odsBuilder.getStringAttr(sym_name));
1141 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1142 mlir::TypeAttr::get(sym_type));
1143
1144 cir::GlobalLinkageKindAttr linkageAttr =
1145 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1146 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1147
1148 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1149 cir::VisibilityAttr::get(odsBuilder.getContext()));
1150}
1151
1152static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1153 TypeAttr type,
1154 Attribute initAttr) {
1155 if (!op.isDeclaration()) {
1156 p << "= ";
1157 // This also prints the type...
1158 if (initAttr)
1159 printConstant(p, value: initAttr);
1160 } else {
1161 p << ": " << type;
1162 }
1163}
1164
1165static ParseResult
1166parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1167 Attribute &initialValueAttr) {
1168 mlir::Type opTy;
1169 if (parser.parseOptionalEqual().failed()) {
1170 // Absence of equal means a declaration, so we need to parse the type.
1171 // cir.global @a : !cir.int<s, 32>
1172 if (parser.parseColonType(result&: opTy))
1173 return failure();
1174 } else {
1175 // Parse constant with initializer, examples:
1176 // cir.global @y = #cir.fp<1.250000e+00> : !cir.double
1177 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1178 if (parseConstantValue(parser, valueAttr&: initialValueAttr).failed())
1179 return failure();
1180
1181 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1182 "Non-typed attrs shouldn't appear here.");
1183 auto typedAttr = mlir::cast<mlir::TypedAttr>(initialValueAttr);
1184 opTy = typedAttr.getType();
1185 }
1186
1187 typeAttr = TypeAttr::get(opTy);
1188 return success();
1189}
1190
1191//===----------------------------------------------------------------------===//
1192// GetGlobalOp
1193//===----------------------------------------------------------------------===//
1194
1195LogicalResult
1196cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1197 // Verify that the result type underlying pointer type matches the type of
1198 // the referenced cir.global or cir.func op.
1199 mlir::Operation *op =
1200 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1201 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1202 return emitOpError("'")
1203 << getName()
1204 << "' does not reference a valid cir.global or cir.func";
1205
1206 mlir::Type symTy;
1207 if (auto g = dyn_cast<GlobalOp>(op)) {
1208 symTy = g.getSymType();
1209 assert(!cir::MissingFeatures::addressSpace());
1210 assert(!cir::MissingFeatures::opGlobalThreadLocal());
1211 } else if (auto f = dyn_cast<FuncOp>(op)) {
1212 symTy = f.getFunctionType();
1213 } else {
1214 llvm_unreachable("Unexpected operation for GetGlobalOp");
1215 }
1216
1217 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1218 if (!resultType || symTy != resultType.getPointee())
1219 return emitOpError("result type pointee type '")
1220 << resultType.getPointee() << "' does not match type " << symTy
1221 << " of the global @" << getName();
1222
1223 return success();
1224}
1225
1226//===----------------------------------------------------------------------===//
1227// FuncOp
1228//===----------------------------------------------------------------------===//
1229
1230void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1231 StringRef name, FuncType type) {
1232 result.addRegion();
1233 result.addAttribute(SymbolTable::getSymbolAttrName(),
1234 builder.getStringAttr(name));
1235 result.addAttribute(getFunctionTypeAttrName(result.name),
1236 TypeAttr::get(type));
1237}
1238
1239ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1240 llvm::SMLoc loc = parser.getCurrentLocation();
1241 mlir::Builder &builder = parser.getBuilder();
1242
1243 StringAttr nameAttr;
1244 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1245 state.attributes))
1246 return failure();
1247 llvm::SmallVector<OpAsmParser::Argument, 8> arguments;
1248 llvm::SmallVector<mlir::Type> resultTypes;
1249 llvm::SmallVector<DictionaryAttr> resultAttrs;
1250 bool isVariadic = false;
1251 if (function_interface_impl::parseFunctionSignatureWithArguments(
1252 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1253 resultAttrs))
1254 return failure();
1255 llvm::SmallVector<mlir::Type> argTypes;
1256 for (OpAsmParser::Argument &arg : arguments)
1257 argTypes.push_back(arg.type);
1258
1259 if (resultTypes.size() > 1) {
1260 return parser.emitError(
1261 loc, "functions with multiple return types are not supported");
1262 }
1263
1264 mlir::Type returnType =
1265 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
1266 : resultTypes.front());
1267
1268 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
1269 if (!fnType)
1270 return failure();
1271 state.addAttribute(getFunctionTypeAttrName(state.name),
1272 TypeAttr::get(fnType));
1273
1274 // Parse the optional function body.
1275 auto *body = state.addRegion();
1276 OptionalParseResult parseResult = parser.parseOptionalRegion(
1277 *body, arguments, /*enableNameShadowing=*/false);
1278 if (parseResult.has_value()) {
1279 if (failed(*parseResult))
1280 return failure();
1281 // Function body was parsed, make sure its not empty.
1282 if (body->empty())
1283 return parser.emitError(loc, "expected non-empty function body");
1284 }
1285
1286 return success();
1287}
1288
1289bool cir::FuncOp::isDeclaration() {
1290 // TODO(CIR): This function will actually do something once external
1291 // function declarations and aliases are upstreamed.
1292 return false;
1293}
1294
1295mlir::Region *cir::FuncOp::getCallableRegion() {
1296 // TODO(CIR): This function will have special handling for aliases and a
1297 // check for an external function, once those features have been upstreamed.
1298 return &getBody();
1299}
1300
1301void cir::FuncOp::print(OpAsmPrinter &p) {
1302 p << ' ';
1303 p.printSymbolName(getSymName());
1304 cir::FuncType fnType = getFunctionType();
1305 function_interface_impl::printFunctionSignature(
1306 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
1307
1308 // Print the body if this is not an external function.
1309 Region &body = getOperation()->getRegion(0);
1310 if (!body.empty()) {
1311 p << ' ';
1312 p.printRegion(body, /*printEntryBlockArgs=*/false,
1313 /*printBlockTerminators=*/true);
1314 }
1315}
1316
1317//===----------------------------------------------------------------------===//
1318// CIR defined traits
1319//===----------------------------------------------------------------------===//
1320
1321LogicalResult
1322mlir::OpTrait::impl::verifySameFirstOperandAndResultType(Operation *op) {
1323 if (failed(Result: verifyAtLeastNOperands(op, numOperands: 1)) || failed(Result: verifyOneResult(op)))
1324 return failure();
1325
1326 const Type type = op->getResult(idx: 0).getType();
1327 const Type opType = op->getOperand(idx: 0).getType();
1328
1329 if (type != opType)
1330 return op->emitOpError()
1331 << "requires the same type for first operand and result";
1332
1333 return success();
1334}
1335
1336// TODO(CIR): The properties of functions that require verification haven't
1337// been implemented yet.
1338mlir::LogicalResult cir::FuncOp::verify() { return success(); }
1339
1340//===----------------------------------------------------------------------===//
1341// BinOp
1342//===----------------------------------------------------------------------===//
1343LogicalResult cir::BinOp::verify() {
1344 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
1345 bool saturated = getSaturated();
1346
1347 if (!isa<cir::IntType>(getType()) && noWrap)
1348 return emitError()
1349 << "only operations on integer values may have nsw/nuw flags";
1350
1351 bool noWrapOps = getKind() == cir::BinOpKind::Add ||
1352 getKind() == cir::BinOpKind::Sub ||
1353 getKind() == cir::BinOpKind::Mul;
1354
1355 bool saturatedOps =
1356 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
1357
1358 if (noWrap && !noWrapOps)
1359 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
1360 "'sub' and 'mul'";
1361 if (saturated && !saturatedOps)
1362 return emitError() << "The saturated flag is applicable to opcodes: 'add' "
1363 "and 'sub'";
1364 if (noWrap && saturated)
1365 return emitError() << "The nsw/nuw flags and the saturated flag are "
1366 "mutually exclusive";
1367
1368 assert(!cir::MissingFeatures::complexType());
1369 // TODO(cir): verify for complex binops
1370
1371 return mlir::success();
1372}
1373
1374//===----------------------------------------------------------------------===//
1375// TernaryOp
1376//===----------------------------------------------------------------------===//
1377
1378/// Given the region at `point`, or the parent operation if `point` is None,
1379/// return the successor regions. These are the regions that may be selected
1380/// during the flow of control. `operands` is a set of optional attributes that
1381/// correspond to a constant value for each operand, or null if that operand is
1382/// not a constant.
1383void cir::TernaryOp::getSuccessorRegions(
1384 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1385 // The `true` and the `false` region branch back to the parent operation.
1386 if (!point.isParent()) {
1387 regions.push_back(RegionSuccessor(this->getODSResults(0)));
1388 return;
1389 }
1390
1391 // When branching from the parent operation, both the true and false
1392 // regions are considered possible successors
1393 regions.push_back(RegionSuccessor(&getTrueRegion()));
1394 regions.push_back(RegionSuccessor(&getFalseRegion()));
1395}
1396
1397void cir::TernaryOp::build(
1398 OpBuilder &builder, OperationState &result, Value cond,
1399 function_ref<void(OpBuilder &, Location)> trueBuilder,
1400 function_ref<void(OpBuilder &, Location)> falseBuilder) {
1401 result.addOperands(cond);
1402 OpBuilder::InsertionGuard guard(builder);
1403 Region *trueRegion = result.addRegion();
1404 Block *block = builder.createBlock(trueRegion);
1405 trueBuilder(builder, result.location);
1406 Region *falseRegion = result.addRegion();
1407 builder.createBlock(falseRegion);
1408 falseBuilder(builder, result.location);
1409
1410 auto yield = dyn_cast<YieldOp>(block->getTerminator());
1411 assert((yield && yield.getNumOperands() <= 1) &&
1412 "expected zero or one result type");
1413 if (yield.getNumOperands() == 1)
1414 result.addTypes(TypeRange{yield.getOperandTypes().front()});
1415}
1416
1417//===----------------------------------------------------------------------===//
1418// SelectOp
1419//===----------------------------------------------------------------------===//
1420
1421OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1422 mlir::Attribute condition = adaptor.getCondition();
1423 if (condition) {
1424 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1425 return conditionValue ? getTrueValue() : getFalseValue();
1426 }
1427
1428 // cir.select if %0 then x else x -> x
1429 mlir::Attribute trueValue = adaptor.getTrueValue();
1430 mlir::Attribute falseValue = adaptor.getFalseValue();
1431 if (trueValue == falseValue)
1432 return trueValue;
1433 if (getTrueValue() == getFalseValue())
1434 return getTrueValue();
1435
1436 return {};
1437}
1438
1439//===----------------------------------------------------------------------===//
1440// ShiftOp
1441//===----------------------------------------------------------------------===//
1442LogicalResult cir::ShiftOp::verify() {
1443 mlir::Operation *op = getOperation();
1444 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
1445 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
1446 if (!op0VecTy ^ !op1VecTy)
1447 return emitOpError() << "input types cannot be one vector and one scalar";
1448
1449 if (op0VecTy) {
1450 if (op0VecTy.getSize() != op1VecTy.getSize())
1451 return emitOpError() << "input vector types must have the same size";
1452
1453 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
1454 if (!opResultTy)
1455 return emitOpError() << "the type of the result must be a vector "
1456 << "if it is vector shift";
1457
1458 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
1459 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
1460 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
1461 return emitOpError()
1462 << "vector operands do not have the same elements sizes";
1463
1464 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
1465 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
1466 return emitOpError() << "vector operands and result type do not have the "
1467 "same elements sizes";
1468 }
1469
1470 return mlir::success();
1471}
1472
1473//===----------------------------------------------------------------------===//
1474// UnaryOp
1475//===----------------------------------------------------------------------===//
1476
1477LogicalResult cir::UnaryOp::verify() {
1478 switch (getKind()) {
1479 case cir::UnaryOpKind::Inc:
1480 case cir::UnaryOpKind::Dec:
1481 case cir::UnaryOpKind::Plus:
1482 case cir::UnaryOpKind::Minus:
1483 case cir::UnaryOpKind::Not:
1484 // Nothing to verify.
1485 return success();
1486 }
1487
1488 llvm_unreachable("Unknown UnaryOp kind?");
1489}
1490
1491static bool isBoolNot(cir::UnaryOp op) {
1492 return isa<cir::BoolType>(op.getInput().getType()) &&
1493 op.getKind() == cir::UnaryOpKind::Not;
1494}
1495
1496// This folder simplifies the sequential boolean not operations.
1497// For instance, the next two unary operations will be eliminated:
1498//
1499// ```mlir
1500// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
1501// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
1502// ```
1503//
1504// and the argument of the first one (%0) will be used instead.
1505OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
1506 if (isBoolNot(*this))
1507 if (auto previous = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp()))
1508 if (isBoolNot(previous))
1509 return previous.getInput();
1510
1511 return {};
1512}
1513
1514//===----------------------------------------------------------------------===//
1515// GetMemberOp Definitions
1516//===----------------------------------------------------------------------===//
1517
1518LogicalResult cir::GetMemberOp::verify() {
1519 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
1520 if (!recordTy)
1521 return emitError() << "expected pointer to a record type";
1522
1523 if (recordTy.getMembers().size() <= getIndex())
1524 return emitError() << "member index out of bounds";
1525
1526 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
1527 return emitError() << "member type mismatch";
1528
1529 return mlir::success();
1530}
1531
1532//===----------------------------------------------------------------------===//
1533// VecCreateOp
1534//===----------------------------------------------------------------------===//
1535
1536LogicalResult cir::VecCreateOp::verify() {
1537 // Verify that the number of arguments matches the number of elements in the
1538 // vector, and that the type of all the arguments matches the type of the
1539 // elements in the vector.
1540 const cir::VectorType vecTy = getType();
1541 if (getElements().size() != vecTy.getSize()) {
1542 return emitOpError() << "operand count of " << getElements().size()
1543 << " doesn't match vector type " << vecTy
1544 << " element count of " << vecTy.getSize();
1545 }
1546
1547 const mlir::Type elementType = vecTy.getElementType();
1548 for (const mlir::Value element : getElements()) {
1549 if (element.getType() != elementType) {
1550 return emitOpError() << "operand type " << element.getType()
1551 << " doesn't match vector element type "
1552 << elementType;
1553 }
1554 }
1555
1556 return success();
1557}
1558
1559//===----------------------------------------------------------------------===//
1560// VecExtractOp
1561//===----------------------------------------------------------------------===//
1562
1563OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
1564 const auto vectorAttr =
1565 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
1566 if (!vectorAttr)
1567 return {};
1568
1569 const auto indexAttr =
1570 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
1571 if (!indexAttr)
1572 return {};
1573
1574 const mlir::ArrayAttr elements = vectorAttr.getElts();
1575 const uint64_t index = indexAttr.getUInt();
1576 if (index >= elements.size())
1577 return {};
1578
1579 return elements[index];
1580}
1581
1582//===----------------------------------------------------------------------===//
1583// VecShuffleOp
1584//===----------------------------------------------------------------------===//
1585
1586OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
1587 auto vec1Attr =
1588 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
1589 auto vec2Attr =
1590 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
1591 if (!vec1Attr || !vec2Attr)
1592 return {};
1593
1594 mlir::Type vec1ElemTy =
1595 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
1596
1597 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
1598 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
1599 mlir::ArrayAttr indicesElts = adaptor.getIndices();
1600
1601 SmallVector<mlir::Attribute, 16> elements;
1602 elements.reserve(indicesElts.size());
1603
1604 uint64_t vec1Size = vec1Elts.size();
1605 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1606 if (idxAttr.getSInt() == -1) {
1607 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
1608 continue;
1609 }
1610
1611 uint64_t idxValue = idxAttr.getUInt();
1612 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
1613 : vec2Elts[idxValue - vec1Size]);
1614 }
1615
1616 return cir::ConstVectorAttr::get(
1617 getType(), mlir::ArrayAttr::get(getContext(), elements));
1618}
1619
1620LogicalResult cir::VecShuffleOp::verify() {
1621 // The number of elements in the indices array must match the number of
1622 // elements in the result type.
1623 if (getIndices().size() != getResult().getType().getSize()) {
1624 return emitOpError() << ": the number of elements in " << getIndices()
1625 << " and " << getResult().getType() << " don't match";
1626 }
1627
1628 // The element types of the two input vectors and of the result type must
1629 // match.
1630 if (getVec1().getType().getElementType() !=
1631 getResult().getType().getElementType()) {
1632 return emitOpError() << ": element types of " << getVec1().getType()
1633 << " and " << getResult().getType() << " don't match";
1634 }
1635
1636 return success();
1637}
1638
1639//===----------------------------------------------------------------------===//
1640// VecShuffleDynamicOp
1641//===----------------------------------------------------------------------===//
1642
1643OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
1644 mlir::Attribute vec = adaptor.getVec();
1645 mlir::Attribute indices = adaptor.getIndices();
1646 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
1647 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
1648 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
1649 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
1650
1651 mlir::ArrayAttr vecElts = vecAttr.getElts();
1652 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
1653
1654 const uint64_t numElements = vecElts.size();
1655
1656 SmallVector<mlir::Attribute, 16> elements;
1657 elements.reserve(numElements);
1658
1659 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
1660 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1661 uint64_t idxValue = idxAttr.getUInt();
1662 uint64_t newIdx = idxValue & maskBits;
1663 elements.push_back(vecElts[newIdx]);
1664 }
1665
1666 return cir::ConstVectorAttr::get(
1667 getType(), mlir::ArrayAttr::get(getContext(), elements));
1668 }
1669
1670 return {};
1671}
1672
1673LogicalResult cir::VecShuffleDynamicOp::verify() {
1674 // The number of elements in the two input vectors must match.
1675 if (getVec().getType().getSize() !=
1676 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
1677 return emitOpError() << ": the number of elements in " << getVec().getType()
1678 << " and " << getIndices().getType() << " don't match";
1679 }
1680 return success();
1681}
1682
1683//===----------------------------------------------------------------------===//
1684// VecTernaryOp
1685//===----------------------------------------------------------------------===//
1686
1687LogicalResult cir::VecTernaryOp::verify() {
1688 // Verify that the condition operand has the same number of elements as the
1689 // other operands. (The automatic verification already checked that all
1690 // operands are vector types and that the second and third operands are the
1691 // same type.)
1692 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
1693 return emitOpError() << ": the number of elements in "
1694 << getCond().getType() << " and " << getLhs().getType()
1695 << " don't match";
1696 }
1697 return success();
1698}
1699
1700OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
1701 mlir::Attribute cond = adaptor.getCond();
1702 mlir::Attribute lhs = adaptor.getLhs();
1703 mlir::Attribute rhs = adaptor.getRhs();
1704
1705 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
1706 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1707 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1708 return {};
1709 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
1710 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
1711 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
1712
1713 mlir::ArrayAttr condElts = condVec.getElts();
1714
1715 SmallVector<mlir::Attribute, 16> elements;
1716 elements.reserve(condElts.size());
1717
1718 for (const auto &[idx, condAttr] :
1719 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
1720 if (condAttr.getSInt()) {
1721 elements.push_back(lhsVec.getElts()[idx]);
1722 } else {
1723 elements.push_back(rhsVec.getElts()[idx]);
1724 }
1725 }
1726
1727 cir::VectorType vecTy = getLhs().getType();
1728 return cir::ConstVectorAttr::get(
1729 vecTy, mlir::ArrayAttr::get(getContext(), elements));
1730}
1731
1732//===----------------------------------------------------------------------===//
1733// TableGen'd op method definitions
1734//===----------------------------------------------------------------------===//
1735
1736#define GET_OP_CLASSES
1737#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
1738

source code of clang/lib/CIR/Dialect/IR/CIRDialect.cpp