1//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the CIR dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "clang/CIR/Dialect/IR/CIRDialect.h"
14
15#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
16#include "clang/CIR/Dialect/IR/CIRTypes.h"
17
18#include "mlir/Interfaces/ControlFlowInterfaces.h"
19#include "mlir/Interfaces/FunctionImplementation.h"
20
21#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
22#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
23#include "clang/CIR/MissingFeatures.h"
24#include "llvm/Support/LogicalResult.h"
25
26#include <numeric>
27
28using namespace mlir;
29using namespace cir;
30
31//===----------------------------------------------------------------------===//
32// CIR Dialect
33//===----------------------------------------------------------------------===//
34namespace {
35struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
36 using OpAsmDialectInterface::OpAsmDialectInterface;
37
38 AliasResult getAlias(Type type, raw_ostream &os) const final {
39 if (auto recordType = dyn_cast<cir::RecordType>(type)) {
40 StringAttr nameAttr = recordType.getName();
41 if (!nameAttr)
42 os << "rec_anon_" << recordType.getKindAsStr();
43 else
44 os << "rec_" << nameAttr.getValue();
45 return AliasResult::OverridableAlias;
46 }
47 if (auto intType = dyn_cast<cir::IntType>(type)) {
48 // We only provide alias for standard integer types (i.e. integer types
49 // whose width is a power of 2 and at least 8).
50 unsigned width = intType.getWidth();
51 if (width < 8 || !llvm::isPowerOf2_32(Value: width))
52 return AliasResult::NoAlias;
53 os << intType.getAlias();
54 return AliasResult::OverridableAlias;
55 }
56 if (auto voidType = dyn_cast<cir::VoidType>(type)) {
57 os << voidType.getAlias();
58 return AliasResult::OverridableAlias;
59 }
60
61 return AliasResult::NoAlias;
62 }
63
64 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
65 if (auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr)) {
66 os << (boolAttr.getValue() ? "true" : "false");
67 return AliasResult::FinalAlias;
68 }
69 if (auto bitfield = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
70 os << "bfi_" << bitfield.getName().str();
71 return AliasResult::FinalAlias;
72 }
73 return AliasResult::NoAlias;
74 }
75};
76} // namespace
77
/// Registers the CIR types, attributes and operations with this dialect and
/// attaches the CIROpAsmDialectInterface that provides printing aliases.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // The operation list is expanded from the GET_OP_LIST section of the
  // generated CIROps.cpp.inc file.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}
87
88Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
89 mlir::Attribute value,
90 mlir::Type type,
91 mlir::Location loc) {
92 return builder.create<cir::ConstantOp>(loc, type,
93 mlir::cast<mlir::TypedAttr>(value));
94}
95
96//===----------------------------------------------------------------------===//
97// Helpers
98//===----------------------------------------------------------------------===//
99
100// Parses one of the keywords provided in the list `keywords` and returns the
101// position of the parsed keyword in the list. If none of the keywords from the
102// list is parsed, returns -1.
103static int parseOptionalKeywordAlternative(AsmParser &parser,
104 ArrayRef<llvm::StringRef> keywords) {
105 for (auto en : llvm::enumerate(keywords)) {
106 if (succeeded(parser.parseOptionalKeyword(en.value())))
107 return en.index();
108 }
109 return -1;
110}
111
namespace {
/// Maps a CIR enum type to its stringify / max-value helpers so that the
/// generic keyword parsers below can enumerate every enumerator. Only the
/// specializations created through REGISTER_ENUM_TYPE provide members; the
/// primary template is empty.
template <typename Ty> struct EnumTraits {};

// Specializes EnumTraits for `cir::Ty` by forwarding to `stringify<Ty>` and
// `cir::getMaxEnumValFor<Ty>`.
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <> struct EnumTraits<cir::Ty> {                                     \
    static llvm::StringRef stringify(cir::Ty value) {                          \
      return stringify##Ty(value);                                             \
    }                                                                          \
    static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); }    \
  }

REGISTER_ENUM_TYPE(GlobalLinkageKind);
REGISTER_ENUM_TYPE(VisibilityKind);
REGISTER_ENUM_TYPE(SideEffect);
} // namespace
127
128/// Parse an enum from the keyword, or default to the provided default value.
129/// The return type is the enum type by default, unless overriden with the
130/// second template argument.
131template <typename EnumTy, typename RetTy = EnumTy>
132static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
133 llvm::SmallVector<llvm::StringRef, 10> names;
134 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
135 names.push_back(Elt: EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
136
137 int index = parseOptionalKeywordAlternative(parser, keywords: names);
138 if (index == -1)
139 return static_cast<RetTy>(defaultValue);
140 return static_cast<RetTy>(index);
141}
142
143/// Parse an enum from the keyword, return failure if the keyword is not found.
144template <typename EnumTy, typename RetTy = EnumTy>
145static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
146 llvm::SmallVector<llvm::StringRef, 10> names;
147 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
148 names.push_back(Elt: EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
149
150 int index = parseOptionalKeywordAlternative(parser, keywords: names);
151 if (index == -1)
152 return failure();
153 result = static_cast<RetTy>(index);
154 return success();
155}
156
157// Check if a region's termination omission is valid and, if so, creates and
158// inserts the omitted terminator into the region.
159static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
160 SMLoc errLoc) {
161 Location eLoc = parser.getEncodedSourceLoc(loc: parser.getCurrentLocation());
162 OpBuilder builder(parser.getBuilder().getContext());
163
164 // Insert empty block in case the region is empty to ensure the terminator
165 // will be inserted
166 if (region.empty())
167 builder.createBlock(parent: &region);
168
169 Block &block = region.back();
170 // Region is properly terminated: nothing to do.
171 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>())
172 return success();
173
174 // Check for invalid terminator omissions.
175 if (!region.hasOneBlock())
176 return parser.emitError(loc: errLoc,
177 message: "multi-block region must not omit terminator");
178
179 // Terminator was omitted correctly: recreate it.
180 builder.setInsertionPointToEnd(&block);
181 builder.create<cir::YieldOp>(eLoc);
182 return success();
183}
184
185// True if the region's terminator should be omitted.
186static bool omitRegionTerm(mlir::Region &r) {
187 const auto singleNonEmptyBlock = r.hasOneBlock() && !r.back().empty();
188 const auto yieldsNothing = [&r]() {
189 auto y = dyn_cast<cir::YieldOp>(r.back().getTerminator());
190 return y && y.getArgs().empty();
191 };
192 return singleNonEmptyBlock && yieldsNothing();
193}
194
195void printVisibilityAttr(OpAsmPrinter &printer,
196 cir::VisibilityAttr &visibility) {
197 switch (visibility.getValue()) {
198 case cir::VisibilityKind::Hidden:
199 printer << "hidden";
200 break;
201 case cir::VisibilityKind::Protected:
202 printer << "protected";
203 break;
204 case cir::VisibilityKind::Default:
205 break;
206 }
207}
208
209void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
210 cir::VisibilityKind visibilityKind =
211 parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default);
212 visibility = cir::VisibilityAttr::get(parser.getContext(), visibilityKind);
213}
214
215//===----------------------------------------------------------------------===//
216// CIR Custom Parsers/Printers
217//===----------------------------------------------------------------------===//
218
219static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
220 mlir::Region &region) {
221 auto regionLoc = parser.getCurrentLocation();
222 if (parser.parseRegion(region))
223 return failure();
224 if (ensureRegionTerm(parser, region, errLoc: regionLoc).failed())
225 return failure();
226 return success();
227}
228
229static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
230 cir::ScopeOp &op,
231 mlir::Region &region) {
232 printer.printRegion(blocks&: region,
233 /*printEntryBlockArgs=*/false,
234 /*printBlockTerminators=*/!omitRegionTerm(r&: region));
235}
236
237//===----------------------------------------------------------------------===//
238// AllocaOp
239//===----------------------------------------------------------------------===//
240
241void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
242 mlir::OperationState &odsState, mlir::Type addr,
243 mlir::Type allocaType, llvm::StringRef name,
244 mlir::IntegerAttr alignment) {
245 odsState.addAttribute(getAllocaTypeAttrName(odsState.name),
246 mlir::TypeAttr::get(allocaType));
247 odsState.addAttribute(getNameAttrName(odsState.name),
248 odsBuilder.getStringAttr(name));
249 if (alignment) {
250 odsState.addAttribute(getAlignmentAttrName(odsState.name), alignment);
251 }
252 odsState.addTypes(addr);
253}
254
255//===----------------------------------------------------------------------===//
256// BreakOp
257//===----------------------------------------------------------------------===//
258
259LogicalResult cir::BreakOp::verify() {
260 assert(!cir::MissingFeatures::switchOp());
261 if (!getOperation()->getParentOfType<LoopOpInterface>() &&
262 !getOperation()->getParentOfType<SwitchOp>())
263 return emitOpError("must be within a loop");
264 return success();
265}
266
267//===----------------------------------------------------------------------===//
268// ConditionOp
269//===----------------------------------------------------------------------===//
270
271//===----------------------------------
272// BranchOpTerminatorInterface Methods
273//===----------------------------------
274
275void cir::ConditionOp::getSuccessorRegions(
276 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
277 // TODO(cir): The condition value may be folded to a constant, narrowing
278 // down its list of possible successors.
279
280 // Parent is a loop: condition may branch to the body or to the parent op.
281 if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
282 regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
283 regions.emplace_back(loopOp->getResults());
284 }
285
286 assert(!cir::MissingFeatures::awaitOp());
287}
288
289MutableOperandRange
290cir::ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
291 // No values are yielded to the successor region.
292 return MutableOperandRange(getOperation(), 0, 0);
293}
294
295LogicalResult cir::ConditionOp::verify() {
296 assert(!cir::MissingFeatures::awaitOp());
297 if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
298 return emitOpError("condition must be within a conditional region");
299 return success();
300}
301
302//===----------------------------------------------------------------------===//
303// ConstantOp
304//===----------------------------------------------------------------------===//
305
306static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
307 mlir::Attribute attrType) {
308 if (isa<cir::ConstPtrAttr>(attrType)) {
309 if (!mlir::isa<cir::PointerType>(opType))
310 return op->emitOpError(
311 message: "pointer constant initializing a non-pointer type");
312 return success();
313 }
314
315 if (isa<cir::ZeroAttr>(attrType)) {
316 if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
317 opType))
318 return success();
319 return op->emitOpError(
320 message: "zero expects struct, array, vector, or complex type");
321 }
322
323 if (mlir::isa<cir::BoolAttr>(attrType)) {
324 if (!mlir::isa<cir::BoolType>(opType))
325 return op->emitOpError(message: "result type (")
326 << opType << ") must be '!cir.bool' for '" << attrType << "'";
327 return success();
328 }
329
330 if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
331 auto at = cast<TypedAttr>(Val&: attrType);
332 if (at.getType() != opType) {
333 return op->emitOpError(message: "result type (")
334 << opType << ") does not match value type (" << at.getType()
335 << ")";
336 }
337 return success();
338 }
339
340 if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
341 cir::ConstComplexAttr>(attrType))
342 return success();
343
344 assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
345 return op->emitOpError(message: "global with type ")
346 << cast<TypedAttr>(Val&: attrType).getType() << " not yet supported";
347}
348
349LogicalResult cir::ConstantOp::verify() {
350 // ODS already generates checks to make sure the result type is valid. We just
351 // need to additionally check that the value's attribute type is consistent
352 // with the result type.
353 return checkConstantTypes(getOperation(), getType(), getValue());
354}
355
356OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
357 return getValue();
358}
359
360//===----------------------------------------------------------------------===//
361// ContinueOp
362//===----------------------------------------------------------------------===//
363
364LogicalResult cir::ContinueOp::verify() {
365 if (!getOperation()->getParentOfType<LoopOpInterface>())
366 return emitOpError("must be within a loop");
367 return success();
368}
369
370//===----------------------------------------------------------------------===//
371// CastOp
372//===----------------------------------------------------------------------===//
373
374LogicalResult cir::CastOp::verify() {
375 mlir::Type resType = getType();
376 mlir::Type srcType = getSrc().getType();
377
378 if (mlir::isa<cir::VectorType>(srcType) &&
379 mlir::isa<cir::VectorType>(resType)) {
380 // Use the element type of the vector to verify the cast kind. (Except for
381 // bitcast, see below.)
382 srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
383 resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
384 }
385
386 switch (getKind()) {
387 case cir::CastKind::int_to_bool: {
388 if (!mlir::isa<cir::BoolType>(resType))
389 return emitOpError() << "requires !cir.bool type for result";
390 if (!mlir::isa<cir::IntType>(srcType))
391 return emitOpError() << "requires !cir.int type for source";
392 return success();
393 }
394 case cir::CastKind::ptr_to_bool: {
395 if (!mlir::isa<cir::BoolType>(resType))
396 return emitOpError() << "requires !cir.bool type for result";
397 if (!mlir::isa<cir::PointerType>(srcType))
398 return emitOpError() << "requires !cir.ptr type for source";
399 return success();
400 }
401 case cir::CastKind::integral: {
402 if (!mlir::isa<cir::IntType>(resType))
403 return emitOpError() << "requires !cir.int type for result";
404 if (!mlir::isa<cir::IntType>(srcType))
405 return emitOpError() << "requires !cir.int type for source";
406 return success();
407 }
408 case cir::CastKind::array_to_ptrdecay: {
409 const auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
410 const auto flatPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
411 if (!arrayPtrTy || !flatPtrTy)
412 return emitOpError() << "requires !cir.ptr type for source and result";
413
414 // TODO(CIR): Make sure the AddrSpace of both types are equals
415 return success();
416 }
417 case cir::CastKind::bitcast: {
418 // Handle the pointer types first.
419 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
420 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
421
422 if (srcPtrTy && resPtrTy) {
423 return success();
424 }
425
426 return success();
427 }
428 case cir::CastKind::floating: {
429 if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
430 !mlir::isa<cir::FPTypeInterface>(resType))
431 return emitOpError() << "requires !cir.float type for source and result";
432 return success();
433 }
434 case cir::CastKind::float_to_int: {
435 if (!mlir::isa<cir::FPTypeInterface>(srcType))
436 return emitOpError() << "requires !cir.float type for source";
437 if (!mlir::dyn_cast<cir::IntType>(resType))
438 return emitOpError() << "requires !cir.int type for result";
439 return success();
440 }
441 case cir::CastKind::int_to_ptr: {
442 if (!mlir::dyn_cast<cir::IntType>(srcType))
443 return emitOpError() << "requires !cir.int type for source";
444 if (!mlir::dyn_cast<cir::PointerType>(resType))
445 return emitOpError() << "requires !cir.ptr type for result";
446 return success();
447 }
448 case cir::CastKind::ptr_to_int: {
449 if (!mlir::dyn_cast<cir::PointerType>(srcType))
450 return emitOpError() << "requires !cir.ptr type for source";
451 if (!mlir::dyn_cast<cir::IntType>(resType))
452 return emitOpError() << "requires !cir.int type for result";
453 return success();
454 }
455 case cir::CastKind::float_to_bool: {
456 if (!mlir::isa<cir::FPTypeInterface>(srcType))
457 return emitOpError() << "requires !cir.float type for source";
458 if (!mlir::isa<cir::BoolType>(resType))
459 return emitOpError() << "requires !cir.bool type for result";
460 return success();
461 }
462 case cir::CastKind::bool_to_int: {
463 if (!mlir::isa<cir::BoolType>(srcType))
464 return emitOpError() << "requires !cir.bool type for source";
465 if (!mlir::isa<cir::IntType>(resType))
466 return emitOpError() << "requires !cir.int type for result";
467 return success();
468 }
469 case cir::CastKind::int_to_float: {
470 if (!mlir::isa<cir::IntType>(srcType))
471 return emitOpError() << "requires !cir.int type for source";
472 if (!mlir::isa<cir::FPTypeInterface>(resType))
473 return emitOpError() << "requires !cir.float type for result";
474 return success();
475 }
476 case cir::CastKind::bool_to_float: {
477 if (!mlir::isa<cir::BoolType>(srcType))
478 return emitOpError() << "requires !cir.bool type for source";
479 if (!mlir::isa<cir::FPTypeInterface>(resType))
480 return emitOpError() << "requires !cir.float type for result";
481 return success();
482 }
483 case cir::CastKind::address_space: {
484 auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
485 auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
486 if (!srcPtrTy || !resPtrTy)
487 return emitOpError() << "requires !cir.ptr type for source and result";
488 if (srcPtrTy.getPointee() != resPtrTy.getPointee())
489 return emitOpError() << "requires two types differ in addrspace only";
490 return success();
491 }
492 default:
493 llvm_unreachable("Unknown CastOp kind?");
494 }
495}
496
497static bool isIntOrBoolCast(cir::CastOp op) {
498 auto kind = op.getKind();
499 return kind == cir::CastKind::bool_to_int ||
500 kind == cir::CastKind::int_to_bool || kind == cir::CastKind::integral;
501}
502
503static Value tryFoldCastChain(cir::CastOp op) {
504 cir::CastOp head = op, tail = op;
505
506 while (op) {
507 if (!isIntOrBoolCast(op))
508 break;
509 head = op;
510 op = dyn_cast_or_null<cir::CastOp>(head.getSrc().getDefiningOp());
511 }
512
513 if (head == tail)
514 return {};
515
516 // if bool_to_int -> ... -> int_to_bool: take the bool
517 // as we had it was before all casts
518 if (head.getKind() == cir::CastKind::bool_to_int &&
519 tail.getKind() == cir::CastKind::int_to_bool)
520 return head.getSrc();
521
522 // if int_to_bool -> ... -> int_to_bool: take the result
523 // of the first one, as no other casts (and ext casts as well)
524 // don't change the first result
525 if (head.getKind() == cir::CastKind::int_to_bool &&
526 tail.getKind() == cir::CastKind::int_to_bool)
527 return head.getResult();
528
529 return {};
530}
531
532OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
533 if (getSrc().getType() == getType()) {
534 switch (getKind()) {
535 case cir::CastKind::integral: {
536 // TODO: for sign differences, it's possible in certain conditions to
537 // create a new attribute that's capable of representing the source.
538 llvm::SmallVector<mlir::OpFoldResult, 1> foldResults;
539 auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
540 if (foldOrder.succeeded() && mlir::isa<mlir::Attribute>(foldResults[0]))
541 return mlir::cast<mlir::Attribute>(foldResults[0]);
542 return {};
543 }
544 case cir::CastKind::bitcast:
545 case cir::CastKind::address_space:
546 case cir::CastKind::float_complex:
547 case cir::CastKind::int_complex: {
548 return getSrc();
549 }
550 default:
551 return {};
552 }
553 }
554 return tryFoldCastChain(*this);
555}
556
557//===----------------------------------------------------------------------===//
558// CallOp
559//===----------------------------------------------------------------------===//
560
561mlir::OperandRange cir::CallOp::getArgOperands() {
562 if (isIndirect())
563 return getArgs().drop_front(1);
564 return getArgs();
565}
566
567mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
568 mlir::MutableOperandRange args = getArgsMutable();
569 if (isIndirect())
570 return args.slice(1, args.size() - 1);
571 return args;
572}
573
574mlir::Value cir::CallOp::getIndirectCall() {
575 assert(isIndirect());
576 return getOperand(0);
577}
578
579/// Return the operand at index 'i'.
580Value cir::CallOp::getArgOperand(unsigned i) {
581 if (isIndirect())
582 ++i;
583 return getOperand(i);
584}
585
586/// Return the number of operands.
587unsigned cir::CallOp::getNumArgOperands() {
588 if (isIndirect())
589 return this->getOperation()->getNumOperands() - 1;
590 return this->getOperation()->getNumOperands();
591}
592
593static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
594 mlir::OperationState &result) {
595 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
596 llvm::SMLoc opsLoc;
597 mlir::FlatSymbolRefAttr calleeAttr;
598 llvm::ArrayRef<mlir::Type> allResultTypes;
599
600 // If we cannot parse a string callee, it means this is an indirect call.
601 if (!parser
602 .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
603 result.attributes)
604 .has_value()) {
605 OpAsmParser::UnresolvedOperand indirectVal;
606 // Do not resolve right now, since we need to figure out the type
607 if (parser.parseOperand(result&: indirectVal).failed())
608 return failure();
609 ops.push_back(Elt: indirectVal);
610 }
611
612 if (parser.parseLParen())
613 return mlir::failure();
614
615 opsLoc = parser.getCurrentLocation();
616 if (parser.parseOperandList(result&: ops))
617 return mlir::failure();
618 if (parser.parseRParen())
619 return mlir::failure();
620
621 if (parser.parseOptionalKeyword("nothrow").succeeded())
622 result.addAttribute(CIRDialect::getNoThrowAttrName(),
623 mlir::UnitAttr::get(parser.getContext()));
624
625 if (parser.parseOptionalKeyword(keyword: "side_effect").succeeded()) {
626 if (parser.parseLParen().failed())
627 return failure();
628 cir::SideEffect sideEffect;
629 if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
630 return failure();
631 if (parser.parseRParen().failed())
632 return failure();
633 auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
634 result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
635 }
636
637 if (parser.parseOptionalAttrDict(result&: result.attributes))
638 return ::mlir::failure();
639
640 if (parser.parseColon())
641 return ::mlir::failure();
642
643 mlir::FunctionType opsFnTy;
644 if (parser.parseType(result&: opsFnTy))
645 return mlir::failure();
646
647 allResultTypes = opsFnTy.getResults();
648 result.addTypes(newTypes: allResultTypes);
649
650 if (parser.resolveOperands(operands&: ops, types: opsFnTy.getInputs(), loc: opsLoc, result&: result.operands))
651 return mlir::failure();
652
653 return mlir::success();
654}
655
656static void printCallCommon(mlir::Operation *op,
657 mlir::FlatSymbolRefAttr calleeSym,
658 mlir::Value indirectCallee,
659 mlir::OpAsmPrinter &printer, bool isNothrow,
660 cir::SideEffect sideEffect) {
661 printer << ' ';
662
663 auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
664 auto ops = callLikeOp.getArgOperands();
665
666 if (calleeSym) {
667 // Direct calls
668 printer.printAttributeWithoutType(attr: calleeSym);
669 } else {
670 // Indirect calls
671 assert(indirectCallee);
672 printer << indirectCallee;
673 }
674 printer << "(" << ops << ")";
675
676 if (isNothrow)
677 printer << " nothrow";
678
679 if (sideEffect != cir::SideEffect::All) {
680 printer << " side_effect(";
681 printer << stringifySideEffect(sideEffect);
682 printer << ")";
683 }
684
685 printer.printOptionalAttrDict(op->getAttrs(),
686 {CIRDialect::getCalleeAttrName(),
687 CIRDialect::getNoThrowAttrName(),
688 CIRDialect::getSideEffectAttrName()});
689
690 printer << " : ";
691 printer.printFunctionalType(inputs: op->getOperands().getTypes(),
692 results: op->getResultTypes());
693}
694
695mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
696 mlir::OperationState &result) {
697 return parseCallCommon(parser, result);
698}
699
700void cir::CallOp::print(mlir::OpAsmPrinter &p) {
701 mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
702 cir::SideEffect sideEffect = getSideEffect();
703 printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
704 sideEffect);
705}
706
707static LogicalResult
708verifyCallCommInSymbolUses(mlir::Operation *op,
709 SymbolTableCollection &symbolTable) {
710 auto fnAttr =
711 op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
712 if (!fnAttr) {
713 // This is an indirect call, thus we don't have to check the symbol uses.
714 return mlir::success();
715 }
716
717 auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
718 if (!fn)
719 return op->emitOpError() << "'" << fnAttr.getValue()
720 << "' does not reference a valid function";
721
722 auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
723 assert(callIf && "expected CIR call interface to be always available");
724
725 // Verify that the operand and result types match the callee. Note that
726 // argument-checking is disabled for functions without a prototype.
727 auto fnType = fn.getFunctionType();
728 if (!fn.getNoProto()) {
729 unsigned numCallOperands = callIf.getNumArgOperands();
730 unsigned numFnOpOperands = fnType.getNumInputs();
731
732 if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
733 return op->emitOpError(message: "incorrect number of operands for callee");
734 if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
735 return op->emitOpError(message: "too few operands for callee");
736
737 for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
738 if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
739 return op->emitOpError(message: "operand type mismatch: expected operand type ")
740 << fnType.getInput(i) << ", but provided "
741 << op->getOperand(idx: i).getType() << " for operand number " << i;
742 }
743
744 assert(!cir::MissingFeatures::opCallCallConv());
745
746 // Void function must not return any results.
747 if (fnType.hasVoidReturn() && op->getNumResults() != 0)
748 return op->emitOpError(message: "callee returns void but call has results");
749
750 // Non-void function calls must return exactly one result.
751 if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
752 return op->emitOpError(message: "incorrect number of results for callee");
753
754 // Parent function and return value types must match.
755 if (!fnType.hasVoidReturn() &&
756 op->getResultTypes().front() != fnType.getReturnType()) {
757 return op->emitOpError(message: "result type mismatch: expected ")
758 << fnType.getReturnType() << ", but provided "
759 << op->getResult(idx: 0).getType();
760 }
761
762 return mlir::success();
763}
764
765LogicalResult
766cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
767 return verifyCallCommInSymbolUses(*this, symbolTable);
768}
769
770//===----------------------------------------------------------------------===//
771// ReturnOp
772//===----------------------------------------------------------------------===//
773
774static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
775 cir::FuncOp function) {
776 // ReturnOps currently only have a single optional operand.
777 if (op.getNumOperands() > 1)
778 return op.emitOpError() << "expects at most 1 return operand";
779
780 // Ensure returned type matches the function signature.
781 auto expectedTy = function.getFunctionType().getReturnType();
782 auto actualTy =
783 (op.getNumOperands() == 0 ? cir::VoidType::get(op.getContext())
784 : op.getOperand(0).getType());
785 if (actualTy != expectedTy)
786 return op.emitOpError() << "returns " << actualTy
787 << " but enclosing function returns " << expectedTy;
788
789 return mlir::success();
790}
791
792mlir::LogicalResult cir::ReturnOp::verify() {
793 // Returns can be present in multiple different scopes, get the
794 // wrapping function and start from there.
795 auto *fnOp = getOperation()->getParentOp();
796 while (!isa<cir::FuncOp>(fnOp))
797 fnOp = fnOp->getParentOp();
798
799 // Make sure return types match function return type.
800 if (checkReturnAndFunction(*this, cast<cir::FuncOp>(fnOp)).failed())
801 return failure();
802
803 return success();
804}
805
806//===----------------------------------------------------------------------===//
807// IfOp
808//===----------------------------------------------------------------------===//
809
810ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
811 // create the regions for 'then'.
812 result.regions.reserve(2);
813 Region *thenRegion = result.addRegion();
814 Region *elseRegion = result.addRegion();
815
816 mlir::Builder &builder = parser.getBuilder();
817 OpAsmParser::UnresolvedOperand cond;
818 Type boolType = cir::BoolType::get(builder.getContext());
819
820 if (parser.parseOperand(cond) ||
821 parser.resolveOperand(cond, boolType, result.operands))
822 return failure();
823
824 // Parse 'then' region.
825 mlir::SMLoc parseThenLoc = parser.getCurrentLocation();
826 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
827 return failure();
828
829 if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
830 return failure();
831
832 // If we find an 'else' keyword, parse the 'else' region.
833 if (!parser.parseOptionalKeyword("else")) {
834 mlir::SMLoc parseElseLoc = parser.getCurrentLocation();
835 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
836 return failure();
837 if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
838 return failure();
839 }
840
841 // Parse the optional attribute list.
842 if (parser.parseOptionalAttrDict(result.attributes))
843 return failure();
844 return success();
845}
846
847void cir::IfOp::print(OpAsmPrinter &p) {
848 p << " " << getCondition() << " ";
849 mlir::Region &thenRegion = this->getThenRegion();
850 p.printRegion(thenRegion,
851 /*printEntryBlockArgs=*/false,
852 /*printBlockTerminators=*/!omitRegionTerm(thenRegion));
853
854 // Print the 'else' regions if it exists and has a block.
855 mlir::Region &elseRegion = this->getElseRegion();
856 if (!elseRegion.empty()) {
857 p << " else ";
858 p.printRegion(elseRegion,
859 /*printEntryBlockArgs=*/false,
860 /*printBlockTerminators=*/!omitRegionTerm(elseRegion));
861 }
862
863 p.printOptionalAttrDict(getOperation()->getAttrs());
864}
865
866/// Default callback for IfOp builders.
867void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
868 // add cir.yield to end of the block
869 builder.create<cir::YieldOp>(loc);
870}
871
872/// Given the region at `index`, or the parent operation if `index` is None,
873/// return the successor regions. These are the regions that may be selected
874/// during the flow of control. `operands` is a set of optional attributes that
875/// correspond to a constant value for each operand, or null if that operand is
876/// not a constant.
877void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
878 SmallVectorImpl<RegionSuccessor> &regions) {
879 // The `then` and the `else` region branch back to the parent operation.
880 if (!point.isParent()) {
881 regions.push_back(RegionSuccessor());
882 return;
883 }
884
885 // Don't consider the else region if it is empty.
886 Region *elseRegion = &this->getElseRegion();
887 if (elseRegion->empty())
888 elseRegion = nullptr;
889
890 // If the condition isn't constant, both regions may be executed.
891 regions.push_back(RegionSuccessor(&getThenRegion()));
892 // If the else region does not exist, it is not a viable successor.
893 if (elseRegion)
894 regions.push_back(RegionSuccessor(elseRegion));
895
896 return;
897}
898
899void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
900 bool withElseRegion, BuilderCallbackRef thenBuilder,
901 BuilderCallbackRef elseBuilder) {
902 assert(thenBuilder && "the builder callback for 'then' must be present");
903 result.addOperands(cond);
904
905 OpBuilder::InsertionGuard guard(builder);
906 Region *thenRegion = result.addRegion();
907 builder.createBlock(thenRegion);
908 thenBuilder(builder, result.location);
909
910 Region *elseRegion = result.addRegion();
911 if (!withElseRegion)
912 return;
913
914 builder.createBlock(elseRegion);
915 elseBuilder(builder, result.location);
916}
917
918//===----------------------------------------------------------------------===//
919// ScopeOp
920//===----------------------------------------------------------------------===//
921
922/// Given the region at `index`, or the parent operation if `index` is None,
923/// return the successor regions. These are the regions that may be selected
924/// during the flow of control. `operands` is a set of optional attributes
925/// that correspond to a constant value for each operand, or null if that
926/// operand is not a constant.
927void cir::ScopeOp::getSuccessorRegions(
928 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
929 // The only region always branch back to the parent operation.
930 if (!point.isParent()) {
931 regions.push_back(RegionSuccessor(getODSResults(0)));
932 return;
933 }
934
935 // If the condition isn't constant, both regions may be executed.
936 regions.push_back(RegionSuccessor(&getScopeRegion()));
937}
938
939void cir::ScopeOp::build(
940 OpBuilder &builder, OperationState &result,
941 function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
942 assert(scopeBuilder && "the builder callback for 'then' must be present");
943
944 OpBuilder::InsertionGuard guard(builder);
945 Region *scopeRegion = result.addRegion();
946 builder.createBlock(scopeRegion);
947 assert(!cir::MissingFeatures::opScopeCleanupRegion());
948
949 mlir::Type yieldTy;
950 scopeBuilder(builder, yieldTy, result.location);
951
952 if (yieldTy)
953 result.addTypes(TypeRange{yieldTy});
954}
955
956void cir::ScopeOp::build(
957 OpBuilder &builder, OperationState &result,
958 function_ref<void(OpBuilder &, Location)> scopeBuilder) {
959 assert(scopeBuilder && "the builder callback for 'then' must be present");
960 OpBuilder::InsertionGuard guard(builder);
961 Region *scopeRegion = result.addRegion();
962 builder.createBlock(scopeRegion);
963 assert(!cir::MissingFeatures::opScopeCleanupRegion());
964 scopeBuilder(builder, result.location);
965}
966
967LogicalResult cir::ScopeOp::verify() {
968 if (getRegion().empty()) {
969 return emitOpError() << "cir.scope must not be empty since it should "
970 "include at least an implicit cir.yield ";
971 }
972
973 mlir::Block &lastBlock = getRegion().back();
974 if (lastBlock.empty() || !lastBlock.mightHaveTerminator() ||
975 !lastBlock.getTerminator()->hasTrait<OpTrait::IsTerminator>())
976 return emitOpError() << "last block of cir.scope must be terminated";
977 return success();
978}
979
980//===----------------------------------------------------------------------===//
981// BrOp
982//===----------------------------------------------------------------------===//
983
984mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
985 assert(index == 0 && "invalid successor index");
986 return mlir::SuccessorOperands(getDestOperandsMutable());
987}
988
989Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute>) {
990 return getDest();
991}
992
993//===----------------------------------------------------------------------===//
994// BrCondOp
995//===----------------------------------------------------------------------===//
996
997mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
998 assert(index < getNumSuccessors() && "invalid successor index");
999 return SuccessorOperands(index == 0 ? getDestOperandsTrueMutable()
1000 : getDestOperandsFalseMutable());
1001}
1002
1003Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
1004 if (IntegerAttr condAttr = dyn_cast_if_present<IntegerAttr>(operands.front()))
1005 return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
1006 return nullptr;
1007}
1008
1009//===----------------------------------------------------------------------===//
1010// CaseOp
1011//===----------------------------------------------------------------------===//
1012
1013void cir::CaseOp::getSuccessorRegions(
1014 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1015 if (!point.isParent()) {
1016 regions.push_back(RegionSuccessor());
1017 return;
1018 }
1019 regions.push_back(RegionSuccessor(&getCaseRegion()));
1020}
1021
1022void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
1023 ArrayAttr value, CaseOpKind kind,
1024 OpBuilder::InsertPoint &insertPoint) {
1025 OpBuilder::InsertionGuard guardSwitch(builder);
1026 result.addAttribute("value", value);
1027 result.getOrAddProperties<Properties>().kind =
1028 cir::CaseOpKindAttr::get(builder.getContext(), kind);
1029 Region *caseRegion = result.addRegion();
1030 builder.createBlock(caseRegion);
1031
1032 insertPoint = builder.saveInsertionPoint();
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// SwitchOp
1037//===----------------------------------------------------------------------===//
1038
1039static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
1040 mlir::OpAsmParser::UnresolvedOperand &cond,
1041 mlir::Type &condType) {
1042 cir::IntType intCondType;
1043
1044 if (parser.parseLParen())
1045 return mlir::failure();
1046
1047 if (parser.parseOperand(result&: cond))
1048 return mlir::failure();
1049 if (parser.parseColon())
1050 return mlir::failure();
1051 if (parser.parseCustomTypeWithFallback(intCondType))
1052 return mlir::failure();
1053 condType = intCondType;
1054
1055 if (parser.parseRParen())
1056 return mlir::failure();
1057 if (parser.parseRegion(region&: regions, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
1058 return failure();
1059
1060 return mlir::success();
1061}
1062
1063static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
1064 mlir::Region &bodyRegion, mlir::Value condition,
1065 mlir::Type condType) {
1066 p << "(";
1067 p << condition;
1068 p << " : ";
1069 p.printStrippedAttrOrType(attrOrType: condType);
1070 p << ")";
1071
1072 p << ' ';
1073 p.printRegion(blocks&: bodyRegion, /*printEntryBlockArgs=*/false,
1074 /*printBlockTerminators=*/true);
1075}
1076
1077void cir::SwitchOp::getSuccessorRegions(
1078 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
1079 if (!point.isParent()) {
1080 region.push_back(RegionSuccessor());
1081 return;
1082 }
1083
1084 region.push_back(RegionSuccessor(&getBody()));
1085}
1086
1087void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
1088 Value cond, BuilderOpStateCallbackRef switchBuilder) {
1089 assert(switchBuilder && "the builder callback for regions must be present");
1090 OpBuilder::InsertionGuard guardSwitch(builder);
1091 Region *switchRegion = result.addRegion();
1092 builder.createBlock(switchRegion);
1093 result.addOperands({cond});
1094 switchBuilder(builder, result.location, result);
1095}
1096
1097void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
1098 walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1099 // Don't walk in nested switch op.
1100 if (isa<cir::SwitchOp>(op) && op != *this)
1101 return WalkResult::skip();
1102
1103 if (auto caseOp = dyn_cast<cir::CaseOp>(op))
1104 cases.push_back(caseOp);
1105
1106 return WalkResult::advance();
1107 });
1108}
1109
1110bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
1111 collectCases(cases);
1112
1113 if (getBody().empty())
1114 return false;
1115
1116 if (!isa<YieldOp>(getBody().front().back()))
1117 return false;
1118
1119 if (!llvm::all_of(getBody().front(),
1120 [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
1121 return false;
1122
1123 return llvm::all_of(cases, [this](CaseOp op) {
1124 return op->getParentOfType<SwitchOp>() == *this;
1125 });
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// SwitchFlatOp
1130//===----------------------------------------------------------------------===//
1131
1132void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
1133 Value value, Block *defaultDestination,
1134 ValueRange defaultOperands,
1135 ArrayRef<APInt> caseValues,
1136 BlockRange caseDestinations,
1137 ArrayRef<ValueRange> caseOperands) {
1138
1139 std::vector<mlir::Attribute> caseValuesAttrs;
1140 for (const APInt &val : caseValues)
1141 caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
1142 mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
1143
1144 build(builder, result, value, defaultOperands, caseOperands, attrs,
1145 defaultDestination, caseDestinations);
1146}
1147
1148/// <cases> ::= `[` (case (`,` case )* )? `]`
1149/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
1150static ParseResult parseSwitchFlatOpCases(
1151 OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
1152 SmallVectorImpl<Block *> &caseDestinations,
1153 SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
1154 &caseOperands,
1155 SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
1156 if (failed(Result: parser.parseLSquare()))
1157 return failure();
1158 if (succeeded(Result: parser.parseOptionalRSquare()))
1159 return success();
1160 llvm::SmallVector<mlir::Attribute> values;
1161
1162 auto parseCase = [&]() {
1163 int64_t value = 0;
1164 if (failed(Result: parser.parseInteger(result&: value)))
1165 return failure();
1166
1167 values.push_back(cir::IntAttr::get(flagType, value));
1168
1169 Block *destination;
1170 llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
1171 llvm::SmallVector<Type> operandTypes;
1172 if (parser.parseColon() || parser.parseSuccessor(dest&: destination))
1173 return failure();
1174 if (!parser.parseOptionalLParen()) {
1175 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::None,
1176 /*allowResultNumber=*/false) ||
1177 parser.parseColonTypeList(result&: operandTypes) || parser.parseRParen())
1178 return failure();
1179 }
1180 caseDestinations.push_back(Elt: destination);
1181 caseOperands.emplace_back(Args&: operands);
1182 caseOperandTypes.emplace_back(Args&: operandTypes);
1183 return success();
1184 };
1185 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseCase)))
1186 return failure();
1187
1188 caseValues = ArrayAttr::get(context: flagType.getContext(), value: values);
1189
1190 return parser.parseRSquare();
1191}
1192
1193static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
1194 Type flagType, mlir::ArrayAttr caseValues,
1195 SuccessorRange caseDestinations,
1196 OperandRangeRange caseOperands,
1197 const TypeRangeRange &caseOperandTypes) {
1198 p << '[';
1199 p.printNewline();
1200 if (!caseValues) {
1201 p << ']';
1202 return;
1203 }
1204
1205 size_t index = 0;
1206 llvm::interleave(
1207 c: llvm::zip(t&: caseValues, u&: caseDestinations),
1208 each_fn: [&](auto i) {
1209 p << " ";
1210 mlir::Attribute a = std::get<0>(i);
1211 p << mlir::cast<cir::IntAttr>(a).getValue();
1212 p << ": ";
1213 p.printSuccessorAndUseList(successor: std::get<1>(i), succOperands: caseOperands[index++]);
1214 },
1215 between_fn: [&] {
1216 p << ',';
1217 p.printNewline();
1218 });
1219 p.printNewline();
1220 p << ']';
1221}
1222
1223//===----------------------------------------------------------------------===//
1224// GlobalOp
1225//===----------------------------------------------------------------------===//
1226
1227static ParseResult parseConstantValue(OpAsmParser &parser,
1228 mlir::Attribute &valueAttr) {
1229 NamedAttrList attr;
1230 return parser.parseAttribute(result&: valueAttr, attrName: "value", attrs&: attr);
1231}
1232
1233static void printConstant(OpAsmPrinter &p, Attribute value) {
1234 p.printAttribute(attr: value);
1235}
1236
1237mlir::LogicalResult cir::GlobalOp::verify() {
1238 // Verify that the initial value, if present, is either a unit attribute or
1239 // an attribute CIR supports.
1240 if (getInitialValue().has_value()) {
1241 if (checkConstantTypes(getOperation(), getSymType(), *getInitialValue())
1242 .failed())
1243 return failure();
1244 }
1245
1246 // TODO(CIR): Many other checks for properties that haven't been upstreamed
1247 // yet.
1248
1249 return success();
1250}
1251
1252void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1253 llvm::StringRef sym_name, mlir::Type sym_type,
1254 cir::GlobalLinkageKind linkage) {
1255 odsState.addAttribute(getSymNameAttrName(odsState.name),
1256 odsBuilder.getStringAttr(sym_name));
1257 odsState.addAttribute(getSymTypeAttrName(odsState.name),
1258 mlir::TypeAttr::get(sym_type));
1259
1260 cir::GlobalLinkageKindAttr linkageAttr =
1261 cir::GlobalLinkageKindAttr::get(odsBuilder.getContext(), linkage);
1262 odsState.addAttribute(getLinkageAttrName(odsState.name), linkageAttr);
1263
1264 odsState.addAttribute(getGlobalVisibilityAttrName(odsState.name),
1265 cir::VisibilityAttr::get(odsBuilder.getContext()));
1266}
1267
1268static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
1269 TypeAttr type,
1270 Attribute initAttr) {
1271 if (!op.isDeclaration()) {
1272 p << "= ";
1273 // This also prints the type...
1274 if (initAttr)
1275 printConstant(p, value: initAttr);
1276 } else {
1277 p << ": " << type;
1278 }
1279}
1280
1281static ParseResult
1282parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1283 Attribute &initialValueAttr) {
1284 mlir::Type opTy;
1285 if (parser.parseOptionalEqual().failed()) {
1286 // Absence of equal means a declaration, so we need to parse the type.
1287 // cir.global @a : !cir.int<s, 32>
1288 if (parser.parseColonType(result&: opTy))
1289 return failure();
1290 } else {
1291 // Parse constant with initializer, examples:
1292 // cir.global @y = #cir.fp<1.250000e+00> : !cir.double
1293 // cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
1294 if (parseConstantValue(parser, valueAttr&: initialValueAttr).failed())
1295 return failure();
1296
1297 assert(mlir::isa<mlir::TypedAttr>(initialValueAttr) &&
1298 "Non-typed attrs shouldn't appear here.");
1299 auto typedAttr = mlir::cast<mlir::TypedAttr>(Val&: initialValueAttr);
1300 opTy = typedAttr.getType();
1301 }
1302
1303 typeAttr = TypeAttr::get(type: opTy);
1304 return success();
1305}
1306
1307//===----------------------------------------------------------------------===//
1308// GetGlobalOp
1309//===----------------------------------------------------------------------===//
1310
1311LogicalResult
1312cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1313 // Verify that the result type underlying pointer type matches the type of
1314 // the referenced cir.global or cir.func op.
1315 mlir::Operation *op =
1316 symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
1317 if (op == nullptr || !(isa<GlobalOp>(op) || isa<FuncOp>(op)))
1318 return emitOpError("'")
1319 << getName()
1320 << "' does not reference a valid cir.global or cir.func";
1321
1322 mlir::Type symTy;
1323 if (auto g = dyn_cast<GlobalOp>(op)) {
1324 symTy = g.getSymType();
1325 assert(!cir::MissingFeatures::addressSpace());
1326 assert(!cir::MissingFeatures::opGlobalThreadLocal());
1327 } else if (auto f = dyn_cast<FuncOp>(op)) {
1328 symTy = f.getFunctionType();
1329 } else {
1330 llvm_unreachable("Unexpected operation for GetGlobalOp");
1331 }
1332
1333 auto resultType = dyn_cast<PointerType>(getAddr().getType());
1334 if (!resultType || symTy != resultType.getPointee())
1335 return emitOpError("result type pointee type '")
1336 << resultType.getPointee() << "' does not match type " << symTy
1337 << " of the global @" << getName();
1338
1339 return success();
1340}
1341
1342//===----------------------------------------------------------------------===//
1343// FuncOp
1344//===----------------------------------------------------------------------===//
1345
1346/// Returns the name used for the linkage attribute. This *must* correspond to
1347/// the name of the attribute in ODS.
1348static llvm::StringRef getLinkageAttrNameString() { return "linkage"; }
1349
1350void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1351 StringRef name, FuncType type,
1352 GlobalLinkageKind linkage) {
1353 result.addRegion();
1354 result.addAttribute(SymbolTable::getSymbolAttrName(),
1355 builder.getStringAttr(name));
1356 result.addAttribute(getFunctionTypeAttrName(result.name),
1357 TypeAttr::get(type));
1358 result.addAttribute(
1359 getLinkageAttrNameString(),
1360 GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1361 result.addAttribute(getGlobalVisibilityAttrName(result.name),
1362 cir::VisibilityAttr::get(builder.getContext()));
1363}
1364
1365ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
1366 llvm::SMLoc loc = parser.getCurrentLocation();
1367 mlir::Builder &builder = parser.getBuilder();
1368
1369 mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
1370 mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
1371 mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
1372
1373 // Default to external linkage if no keyword is provided.
1374 state.addAttribute(getLinkageAttrNameString(),
1375 GlobalLinkageKindAttr::get(
1376 parser.getContext(),
1377 parseOptionalCIRKeyword<GlobalLinkageKind>(
1378 parser, GlobalLinkageKind::ExternalLinkage)));
1379
1380 ::llvm::StringRef visAttrStr;
1381 if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
1382 .succeeded()) {
1383 state.addAttribute(visNameAttr,
1384 parser.getBuilder().getStringAttr(visAttrStr));
1385 }
1386
1387 cir::VisibilityAttr cirVisibilityAttr;
1388 parseVisibilityAttr(parser, cirVisibilityAttr);
1389 state.addAttribute(visibilityNameAttr, cirVisibilityAttr);
1390
1391 if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
1392 state.addAttribute(dsoLocalNameAttr, parser.getBuilder().getUnitAttr());
1393
1394 StringAttr nameAttr;
1395 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1396 state.attributes))
1397 return failure();
1398 llvm::SmallVector<OpAsmParser::Argument, 8> arguments;
1399 llvm::SmallVector<mlir::Type> resultTypes;
1400 llvm::SmallVector<DictionaryAttr> resultAttrs;
1401 bool isVariadic = false;
1402 if (function_interface_impl::parseFunctionSignatureWithArguments(
1403 parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
1404 resultAttrs))
1405 return failure();
1406 llvm::SmallVector<mlir::Type> argTypes;
1407 for (OpAsmParser::Argument &arg : arguments)
1408 argTypes.push_back(arg.type);
1409
1410 if (resultTypes.size() > 1) {
1411 return parser.emitError(
1412 loc, "functions with multiple return types are not supported");
1413 }
1414
1415 mlir::Type returnType =
1416 (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
1417 : resultTypes.front());
1418
1419 cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
1420 if (!fnType)
1421 return failure();
1422 state.addAttribute(getFunctionTypeAttrName(state.name),
1423 TypeAttr::get(fnType));
1424
1425 bool hasAlias = false;
1426 mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
1427 if (parser.parseOptionalKeyword("alias").succeeded()) {
1428 if (parser.parseLParen().failed())
1429 return failure();
1430 mlir::StringAttr aliaseeAttr;
1431 if (parser.parseOptionalSymbolName(aliaseeAttr).failed())
1432 return failure();
1433 state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
1434 if (parser.parseRParen().failed())
1435 return failure();
1436 hasAlias = true;
1437 }
1438
1439 // Parse the optional function body.
1440 auto *body = state.addRegion();
1441 OptionalParseResult parseResult = parser.parseOptionalRegion(
1442 *body, arguments, /*enableNameShadowing=*/false);
1443 if (parseResult.has_value()) {
1444 if (hasAlias)
1445 return parser.emitError(loc, "function alias shall not have a body");
1446 if (failed(*parseResult))
1447 return failure();
1448 // Function body was parsed, make sure its not empty.
1449 if (body->empty())
1450 return parser.emitError(loc, "expected non-empty function body");
1451 }
1452
1453 return success();
1454}
1455
1456// This function corresponds to `llvm::GlobalValue::isDeclaration` and should
1457// have a similar implementation. We don't currently ifuncs or materializable
1458// functions, but those should be handled here as they are implemented.
1459bool cir::FuncOp::isDeclaration() {
1460 assert(!cir::MissingFeatures::supportIFuncAttr());
1461
1462 std::optional<StringRef> aliasee = getAliasee();
1463 if (!aliasee)
1464 return getFunctionBody().empty();
1465
1466 // Aliases are always definitions.
1467 return false;
1468}
1469
1470mlir::Region *cir::FuncOp::getCallableRegion() {
1471 // TODO(CIR): This function will have special handling for aliases and a
1472 // check for an external function, once those features have been upstreamed.
1473 return &getBody();
1474}
1475
1476void cir::FuncOp::print(OpAsmPrinter &p) {
1477 if (getComdat())
1478 p << " comdat";
1479
1480 if (getLinkage() != GlobalLinkageKind::ExternalLinkage)
1481 p << ' ' << stringifyGlobalLinkageKind(getLinkage());
1482
1483 mlir::SymbolTable::Visibility vis = getVisibility();
1484 if (vis != mlir::SymbolTable::Visibility::Public)
1485 p << ' ' << vis;
1486
1487 cir::VisibilityAttr cirVisibilityAttr = getGlobalVisibilityAttr();
1488 if (!cirVisibilityAttr.isDefault()) {
1489 p << ' ';
1490 printVisibilityAttr(p, cirVisibilityAttr);
1491 }
1492
1493 if (getDsoLocal())
1494 p << " dso_local";
1495
1496 p << ' ';
1497 p.printSymbolName(getSymName());
1498 cir::FuncType fnType = getFunctionType();
1499 function_interface_impl::printFunctionSignature(
1500 p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
1501
1502 if (std::optional<StringRef> aliaseeName = getAliasee()) {
1503 p << " alias(";
1504 p.printSymbolName(*aliaseeName);
1505 p << ")";
1506 }
1507
1508 // Print the body if this is not an external function.
1509 Region &body = getOperation()->getRegion(0);
1510 if (!body.empty()) {
1511 p << ' ';
1512 p.printRegion(body, /*printEntryBlockArgs=*/false,
1513 /*printBlockTerminators=*/true);
1514 }
1515}
1516
1517// TODO(CIR): The properties of functions that require verification haven't
1518// been implemented yet.
1519mlir::LogicalResult cir::FuncOp::verify() { return success(); }
1520
1521//===----------------------------------------------------------------------===//
1522// BinOp
1523//===----------------------------------------------------------------------===//
1524LogicalResult cir::BinOp::verify() {
1525 bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
1526 bool saturated = getSaturated();
1527
1528 if (!isa<cir::IntType>(getType()) && noWrap)
1529 return emitError()
1530 << "only operations on integer values may have nsw/nuw flags";
1531
1532 bool noWrapOps = getKind() == cir::BinOpKind::Add ||
1533 getKind() == cir::BinOpKind::Sub ||
1534 getKind() == cir::BinOpKind::Mul;
1535
1536 bool saturatedOps =
1537 getKind() == cir::BinOpKind::Add || getKind() == cir::BinOpKind::Sub;
1538
1539 if (noWrap && !noWrapOps)
1540 return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
1541 "'sub' and 'mul'";
1542 if (saturated && !saturatedOps)
1543 return emitError() << "The saturated flag is applicable to opcodes: 'add' "
1544 "and 'sub'";
1545 if (noWrap && saturated)
1546 return emitError() << "The nsw/nuw flags and the saturated flag are "
1547 "mutually exclusive";
1548
1549 assert(!cir::MissingFeatures::complexType());
1550 // TODO(cir): verify for complex binops
1551
1552 return mlir::success();
1553}
1554
1555//===----------------------------------------------------------------------===//
1556// TernaryOp
1557//===----------------------------------------------------------------------===//
1558
1559/// Given the region at `point`, or the parent operation if `point` is None,
1560/// return the successor regions. These are the regions that may be selected
1561/// during the flow of control. `operands` is a set of optional attributes that
1562/// correspond to a constant value for each operand, or null if that operand is
1563/// not a constant.
1564void cir::TernaryOp::getSuccessorRegions(
1565 mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1566 // The `true` and the `false` region branch back to the parent operation.
1567 if (!point.isParent()) {
1568 regions.push_back(RegionSuccessor(this->getODSResults(0)));
1569 return;
1570 }
1571
1572 // When branching from the parent operation, both the true and false
1573 // regions are considered possible successors
1574 regions.push_back(RegionSuccessor(&getTrueRegion()));
1575 regions.push_back(RegionSuccessor(&getFalseRegion()));
1576}
1577
1578void cir::TernaryOp::build(
1579 OpBuilder &builder, OperationState &result, Value cond,
1580 function_ref<void(OpBuilder &, Location)> trueBuilder,
1581 function_ref<void(OpBuilder &, Location)> falseBuilder) {
1582 result.addOperands(cond);
1583 OpBuilder::InsertionGuard guard(builder);
1584 Region *trueRegion = result.addRegion();
1585 Block *block = builder.createBlock(trueRegion);
1586 trueBuilder(builder, result.location);
1587 Region *falseRegion = result.addRegion();
1588 builder.createBlock(falseRegion);
1589 falseBuilder(builder, result.location);
1590
1591 auto yield = dyn_cast<YieldOp>(block->getTerminator());
1592 assert((yield && yield.getNumOperands() <= 1) &&
1593 "expected zero or one result type");
1594 if (yield.getNumOperands() == 1)
1595 result.addTypes(TypeRange{yield.getOperandTypes().front()});
1596}
1597
1598//===----------------------------------------------------------------------===//
1599// SelectOp
1600//===----------------------------------------------------------------------===//
1601
1602OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1603 mlir::Attribute condition = adaptor.getCondition();
1604 if (condition) {
1605 bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1606 return conditionValue ? getTrueValue() : getFalseValue();
1607 }
1608
1609 // cir.select if %0 then x else x -> x
1610 mlir::Attribute trueValue = adaptor.getTrueValue();
1611 mlir::Attribute falseValue = adaptor.getFalseValue();
1612 if (trueValue == falseValue)
1613 return trueValue;
1614 if (getTrueValue() == getFalseValue())
1615 return getTrueValue();
1616
1617 return {};
1618}
1619
1620//===----------------------------------------------------------------------===//
1621// ShiftOp
1622//===----------------------------------------------------------------------===//
1623LogicalResult cir::ShiftOp::verify() {
1624 mlir::Operation *op = getOperation();
1625 auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
1626 auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
1627 if (!op0VecTy ^ !op1VecTy)
1628 return emitOpError() << "input types cannot be one vector and one scalar";
1629
1630 if (op0VecTy) {
1631 if (op0VecTy.getSize() != op1VecTy.getSize())
1632 return emitOpError() << "input vector types must have the same size";
1633
1634 auto opResultTy = mlir::dyn_cast<cir::VectorType>(getType());
1635 if (!opResultTy)
1636 return emitOpError() << "the type of the result must be a vector "
1637 << "if it is vector shift";
1638
1639 auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
1640 auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
1641 if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
1642 return emitOpError()
1643 << "vector operands do not have the same elements sizes";
1644
1645 auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
1646 if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
1647 return emitOpError() << "vector operands and result type do not have the "
1648 "same elements sizes";
1649 }
1650
1651 return mlir::success();
1652}
1653
1654//===----------------------------------------------------------------------===//
1655// UnaryOp
1656//===----------------------------------------------------------------------===//
1657
1658LogicalResult cir::UnaryOp::verify() {
1659 switch (getKind()) {
1660 case cir::UnaryOpKind::Inc:
1661 case cir::UnaryOpKind::Dec:
1662 case cir::UnaryOpKind::Plus:
1663 case cir::UnaryOpKind::Minus:
1664 case cir::UnaryOpKind::Not:
1665 // Nothing to verify.
1666 return success();
1667 }
1668
1669 llvm_unreachable("Unknown UnaryOp kind?");
1670}
1671
1672static bool isBoolNot(cir::UnaryOp op) {
1673 return isa<cir::BoolType>(op.getInput().getType()) &&
1674 op.getKind() == cir::UnaryOpKind::Not;
1675}
1676
1677// This folder simplifies the sequential boolean not operations.
1678// For instance, the next two unary operations will be eliminated:
1679//
1680// ```mlir
1681// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
1682// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
1683// ```
1684//
1685// and the argument of the first one (%0) will be used instead.
1686OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
1687 if (isBoolNot(*this))
1688 if (auto previous = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp()))
1689 if (isBoolNot(previous))
1690 return previous.getInput();
1691
1692 return {};
1693}
1694
1695//===----------------------------------------------------------------------===//
1696// GetMemberOp Definitions
1697//===----------------------------------------------------------------------===//
1698
1699LogicalResult cir::GetMemberOp::verify() {
1700 const auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
1701 if (!recordTy)
1702 return emitError() << "expected pointer to a record type";
1703
1704 if (recordTy.getMembers().size() <= getIndex())
1705 return emitError() << "member index out of bounds";
1706
1707 if (recordTy.getMembers()[getIndex()] != getType().getPointee())
1708 return emitError() << "member type mismatch";
1709
1710 return mlir::success();
1711}
1712
1713//===----------------------------------------------------------------------===//
1714// VecCreateOp
1715//===----------------------------------------------------------------------===//
1716
1717OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
1718 if (llvm::any_of(getElements(), [](mlir::Value value) {
1719 return !mlir::isa<cir::ConstantOp>(value.getDefiningOp());
1720 }))
1721 return {};
1722
1723 return cir::ConstVectorAttr::get(
1724 getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
1725}
1726
1727LogicalResult cir::VecCreateOp::verify() {
1728 // Verify that the number of arguments matches the number of elements in the
1729 // vector, and that the type of all the arguments matches the type of the
1730 // elements in the vector.
1731 const cir::VectorType vecTy = getType();
1732 if (getElements().size() != vecTy.getSize()) {
1733 return emitOpError() << "operand count of " << getElements().size()
1734 << " doesn't match vector type " << vecTy
1735 << " element count of " << vecTy.getSize();
1736 }
1737
1738 const mlir::Type elementType = vecTy.getElementType();
1739 for (const mlir::Value element : getElements()) {
1740 if (element.getType() != elementType) {
1741 return emitOpError() << "operand type " << element.getType()
1742 << " doesn't match vector element type "
1743 << elementType;
1744 }
1745 }
1746
1747 return success();
1748}
1749
1750//===----------------------------------------------------------------------===//
1751// VecExtractOp
1752//===----------------------------------------------------------------------===//
1753
1754OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
1755 const auto vectorAttr =
1756 llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
1757 if (!vectorAttr)
1758 return {};
1759
1760 const auto indexAttr =
1761 llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
1762 if (!indexAttr)
1763 return {};
1764
1765 const mlir::ArrayAttr elements = vectorAttr.getElts();
1766 const uint64_t index = indexAttr.getUInt();
1767 if (index >= elements.size())
1768 return {};
1769
1770 return elements[index];
1771}
1772
1773//===----------------------------------------------------------------------===//
1774// VecCmpOp
1775//===----------------------------------------------------------------------===//
1776
1777OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1778 auto lhsVecAttr =
1779 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
1780 auto rhsVecAttr =
1781 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
1782 if (!lhsVecAttr || !rhsVecAttr)
1783 return {};
1784
1785 mlir::Type inputElemTy =
1786 mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
1787 if (!isAnyIntegerOrFloatingPointType(inputElemTy))
1788 return {};
1789
1790 cir::CmpOpKind opKind = adaptor.getKind();
1791 mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
1792 mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
1793 uint64_t vecSize = lhsVecElhs.size();
1794
1795 SmallVector<mlir::Attribute, 16> elements(vecSize);
1796 bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
1797 for (uint64_t i = 0; i < vecSize; i++) {
1798 mlir::Attribute lhsAttr = lhsVecElhs[i];
1799 mlir::Attribute rhsAttr = rhsVecElhs[i];
1800 int cmpResult = 0;
1801 switch (opKind) {
1802 case cir::CmpOpKind::lt: {
1803 if (isIntAttr) {
1804 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
1805 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1806 } else {
1807 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
1808 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1809 }
1810 break;
1811 }
1812 case cir::CmpOpKind::le: {
1813 if (isIntAttr) {
1814 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
1815 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1816 } else {
1817 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
1818 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1819 }
1820 break;
1821 }
1822 case cir::CmpOpKind::gt: {
1823 if (isIntAttr) {
1824 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
1825 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1826 } else {
1827 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
1828 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1829 }
1830 break;
1831 }
1832 case cir::CmpOpKind::ge: {
1833 if (isIntAttr) {
1834 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
1835 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1836 } else {
1837 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
1838 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1839 }
1840 break;
1841 }
1842 case cir::CmpOpKind::eq: {
1843 if (isIntAttr) {
1844 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
1845 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1846 } else {
1847 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
1848 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1849 }
1850 break;
1851 }
1852 case cir::CmpOpKind::ne: {
1853 if (isIntAttr) {
1854 cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
1855 mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
1856 } else {
1857 cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
1858 mlir::cast<cir::FPAttr>(rhsAttr).getValue();
1859 }
1860 break;
1861 }
1862 }
1863
1864 elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
1865 }
1866
1867 return cir::ConstVectorAttr::get(
1868 getType(), mlir::ArrayAttr::get(getContext(), elements));
1869}
1870
1871//===----------------------------------------------------------------------===//
1872// VecShuffleOp
1873//===----------------------------------------------------------------------===//
1874
1875OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
1876 auto vec1Attr =
1877 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
1878 auto vec2Attr =
1879 mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
1880 if (!vec1Attr || !vec2Attr)
1881 return {};
1882
1883 mlir::Type vec1ElemTy =
1884 mlir::cast<cir::VectorType>(vec1Attr.getType()).getElementType();
1885
1886 mlir::ArrayAttr vec1Elts = vec1Attr.getElts();
1887 mlir::ArrayAttr vec2Elts = vec2Attr.getElts();
1888 mlir::ArrayAttr indicesElts = adaptor.getIndices();
1889
1890 SmallVector<mlir::Attribute, 16> elements;
1891 elements.reserve(indicesElts.size());
1892
1893 uint64_t vec1Size = vec1Elts.size();
1894 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1895 if (idxAttr.getSInt() == -1) {
1896 elements.push_back(cir::UndefAttr::get(vec1ElemTy));
1897 continue;
1898 }
1899
1900 uint64_t idxValue = idxAttr.getUInt();
1901 elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue]
1902 : vec2Elts[idxValue - vec1Size]);
1903 }
1904
1905 return cir::ConstVectorAttr::get(
1906 getType(), mlir::ArrayAttr::get(getContext(), elements));
1907}
1908
1909LogicalResult cir::VecShuffleOp::verify() {
1910 // The number of elements in the indices array must match the number of
1911 // elements in the result type.
1912 if (getIndices().size() != getResult().getType().getSize()) {
1913 return emitOpError() << ": the number of elements in " << getIndices()
1914 << " and " << getResult().getType() << " don't match";
1915 }
1916
1917 // The element types of the two input vectors and of the result type must
1918 // match.
1919 if (getVec1().getType().getElementType() !=
1920 getResult().getType().getElementType()) {
1921 return emitOpError() << ": element types of " << getVec1().getType()
1922 << " and " << getResult().getType() << " don't match";
1923 }
1924
1925 const uint64_t maxValidIndex =
1926 getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
1927 if (llvm::any_of(
1928 getIndices().getAsRange<cir::IntAttr>(), [&](cir::IntAttr idxAttr) {
1929 return idxAttr.getSInt() != -1 && idxAttr.getUInt() > maxValidIndex;
1930 })) {
1931 return emitOpError() << ": index for __builtin_shufflevector must be "
1932 "less than the total number of vector elements";
1933 }
1934 return success();
1935}
1936
1937//===----------------------------------------------------------------------===//
1938// VecShuffleDynamicOp
1939//===----------------------------------------------------------------------===//
1940
1941OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
1942 mlir::Attribute vec = adaptor.getVec();
1943 mlir::Attribute indices = adaptor.getIndices();
1944 if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(vec) &&
1945 mlir::isa_and_nonnull<cir::ConstVectorAttr>(indices)) {
1946 auto vecAttr = mlir::cast<cir::ConstVectorAttr>(vec);
1947 auto indicesAttr = mlir::cast<cir::ConstVectorAttr>(indices);
1948
1949 mlir::ArrayAttr vecElts = vecAttr.getElts();
1950 mlir::ArrayAttr indicesElts = indicesAttr.getElts();
1951
1952 const uint64_t numElements = vecElts.size();
1953
1954 SmallVector<mlir::Attribute, 16> elements;
1955 elements.reserve(numElements);
1956
1957 const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
1958 for (const auto &idxAttr : indicesElts.getAsRange<cir::IntAttr>()) {
1959 uint64_t idxValue = idxAttr.getUInt();
1960 uint64_t newIdx = idxValue & maskBits;
1961 elements.push_back(vecElts[newIdx]);
1962 }
1963
1964 return cir::ConstVectorAttr::get(
1965 getType(), mlir::ArrayAttr::get(getContext(), elements));
1966 }
1967
1968 return {};
1969}
1970
1971LogicalResult cir::VecShuffleDynamicOp::verify() {
1972 // The number of elements in the two input vectors must match.
1973 if (getVec().getType().getSize() !=
1974 mlir::cast<cir::VectorType>(getIndices().getType()).getSize()) {
1975 return emitOpError() << ": the number of elements in " << getVec().getType()
1976 << " and " << getIndices().getType() << " don't match";
1977 }
1978 return success();
1979}
1980
1981//===----------------------------------------------------------------------===//
1982// VecTernaryOp
1983//===----------------------------------------------------------------------===//
1984
1985LogicalResult cir::VecTernaryOp::verify() {
1986 // Verify that the condition operand has the same number of elements as the
1987 // other operands. (The automatic verification already checked that all
1988 // operands are vector types and that the second and third operands are the
1989 // same type.)
1990 if (getCond().getType().getSize() != getLhs().getType().getSize()) {
1991 return emitOpError() << ": the number of elements in "
1992 << getCond().getType() << " and " << getLhs().getType()
1993 << " don't match";
1994 }
1995 return success();
1996}
1997
1998OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
1999 mlir::Attribute cond = adaptor.getCond();
2000 mlir::Attribute lhs = adaptor.getLhs();
2001 mlir::Attribute rhs = adaptor.getRhs();
2002
2003 if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
2004 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
2005 !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
2006 return {};
2007 auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
2008 auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
2009 auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
2010
2011 mlir::ArrayAttr condElts = condVec.getElts();
2012
2013 SmallVector<mlir::Attribute, 16> elements;
2014 elements.reserve(condElts.size());
2015
2016 for (const auto &[idx, condAttr] :
2017 llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
2018 if (condAttr.getSInt()) {
2019 elements.push_back(lhsVec.getElts()[idx]);
2020 } else {
2021 elements.push_back(rhsVec.getElts()[idx]);
2022 }
2023 }
2024
2025 cir::VectorType vecTy = getLhs().getType();
2026 return cir::ConstVectorAttr::get(
2027 vecTy, mlir::ArrayAttr::get(getContext(), elements));
2028}
2029
2030//===----------------------------------------------------------------------===//
2031// ComplexCreateOp
2032//===----------------------------------------------------------------------===//
2033
2034LogicalResult cir::ComplexCreateOp::verify() {
2035 if (getType().getElementType() != getReal().getType()) {
2036 emitOpError()
2037 << "operand type of cir.complex.create does not match its result type";
2038 return failure();
2039 }
2040
2041 return success();
2042}
2043
2044OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
2045 mlir::Attribute real = adaptor.getReal();
2046 mlir::Attribute imag = adaptor.getImag();
2047 if (!real || !imag)
2048 return {};
2049
2050 // When both of real and imag are constants, we can fold the operation into an
2051 // `#cir.const_complex` operation.
2052 auto realAttr = mlir::cast<mlir::TypedAttr>(real);
2053 auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
2054 return cir::ConstComplexAttr::get(realAttr, imagAttr);
2055}
2056
2057//===----------------------------------------------------------------------===//
2058// ComplexRealOp
2059//===----------------------------------------------------------------------===//
2060
2061LogicalResult cir::ComplexRealOp::verify() {
2062 if (getType() != getOperand().getType().getElementType()) {
2063 emitOpError() << ": result type does not match operand type";
2064 return failure();
2065 }
2066 return success();
2067}
2068
2069OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
2070 if (auto complexCreateOp =
2071 dyn_cast_or_null<cir::ComplexCreateOp>(getOperand().getDefiningOp()))
2072 return complexCreateOp.getOperand(0);
2073
2074 auto complex =
2075 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2076 return complex ? complex.getReal() : nullptr;
2077}
2078
2079//===----------------------------------------------------------------------===//
2080// ComplexImagOp
2081//===----------------------------------------------------------------------===//
2082
2083LogicalResult cir::ComplexImagOp::verify() {
2084 if (getType() != getOperand().getType().getElementType()) {
2085 emitOpError() << ": result type does not match operand type";
2086 return failure();
2087 }
2088 return success();
2089}
2090
2091OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
2092 if (auto complexCreateOp =
2093 dyn_cast_or_null<cir::ComplexCreateOp>(getOperand().getDefiningOp()))
2094 return complexCreateOp.getOperand(1);
2095
2096 auto complex =
2097 mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
2098 return complex ? complex.getImag() : nullptr;
2099}
2100
2101//===----------------------------------------------------------------------===//
2102// ComplexRealPtrOp
2103//===----------------------------------------------------------------------===//
2104
2105LogicalResult cir::ComplexRealPtrOp::verify() {
2106 mlir::Type resultPointeeTy = getType().getPointee();
2107 cir::PointerType operandPtrTy = getOperand().getType();
2108 auto operandPointeeTy =
2109 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2110
2111 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2112 return emitOpError() << ": result type does not match operand type";
2113 }
2114
2115 return success();
2116}
2117
2118//===----------------------------------------------------------------------===//
2119// ComplexImagPtrOp
2120//===----------------------------------------------------------------------===//
2121
2122LogicalResult cir::ComplexImagPtrOp::verify() {
2123 mlir::Type resultPointeeTy = getType().getPointee();
2124 cir::PointerType operandPtrTy = getOperand().getType();
2125 auto operandPointeeTy =
2126 mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2127
2128 if (resultPointeeTy != operandPointeeTy.getElementType()) {
2129 return emitOpError()
2130 << "cir.complex.imag_ptr result type does not match operand type";
2131 }
2132 return success();
2133}
2134
2135//===----------------------------------------------------------------------===//
2136// TableGen'd op method definitions
2137//===----------------------------------------------------------------------===//
2138
2139#define GET_OP_CLASSES
2140#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
2141

source code of clang/lib/CIR/Dialect/IR/CIRDialect.cpp