1//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the LLVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/DialectImplementation.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/Interfaces/FunctionImplementation.h"
25#include "mlir/Transforms/InliningUtils.h"
26
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/IR/Function.h"
29#include "llvm/IR/Type.h"
30#include "llvm/Support/Error.h"
31
32#include <numeric>
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::LLVM;
37using mlir::LLVM::cconv::getMaxEnumValForCConv;
38using mlir::LLVM::linkage::getMaxEnumValForLinkage;
39using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
40
41#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
42
43//===----------------------------------------------------------------------===//
44// Attribute Helpers
45//===----------------------------------------------------------------------===//
46
/// Name of the attribute that records the element type of an `llvm.alloca`.
/// AllocaOp's parser sets it and its printer elides it from the attribute
/// dictionary, because the element type is part of the custom syntax.
static constexpr const char kElemTypeAttrName[] = "elem_type";
48
49static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
50 SmallVector<NamedAttribute, 8> filteredAttrs(
51 llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) {
52 if (attr.getName() == "fastmathFlags") {
53 auto defAttr =
54 FastmathFlagsAttr::get(context: attr.getValue().getContext(), value: {});
55 return defAttr != attr.getValue();
56 }
57 return true;
58 }));
59 return filteredAttrs;
60}
61
62/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
63/// fully defined llvm.func.
64static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
65 Operation *op,
66 SymbolTableCollection &symbolTable) {
67 StringRef name = symbol.getValue();
68 auto func =
69 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(from: op, symbol: symbol.getAttr());
70 if (!func)
71 return op->emitOpError(message: "'")
72 << name << "' does not reference a valid LLVM function";
73 if (func.isExternal())
74 return op->emitOpError(message: "'") << name << "' does not have a definition";
75 return success();
76}
77
78/// Returns a boolean type that has the same shape as `type`. It supports both
79/// fixed size vectors as well as scalable vectors.
80static Type getI1SameShape(Type type) {
81 Type i1Type = IntegerType::get(context: type.getContext(), width: 1);
82 if (LLVM::isCompatibleVectorType(type))
83 return LLVM::getVectorType(elementType: i1Type, numElements: LLVM::getVectorNumElements(type));
84 return i1Type;
85}
86
87// Parses one of the keywords provided in the list `keywords` and returns the
88// position of the parsed keyword in the list. If none of the keywords from the
89// list is parsed, returns -1.
90static int parseOptionalKeywordAlternative(OpAsmParser &parser,
91 ArrayRef<StringRef> keywords) {
92 for (const auto &en : llvm::enumerate(First&: keywords)) {
93 if (succeeded(Result: parser.parseOptionalKeyword(keyword: en.value())))
94 return en.index();
95 }
96 return -1;
97}
98
namespace {
/// Compile-time bridge from an LLVM dialect enum type to its generated
/// `stringify<Enum>` and `getMaxEnumValFor<Enum>` functions. Generic code such
/// as parseOptionalLLVMKeyword uses it to enumerate every keyword of an enum.
/// The primary template is intentionally empty: only enums registered below
/// have traits.
template <typename Ty>
struct EnumTraits {};

// Defines the EnumTraits specialization for `Ty`. The functions
// `stringify##Ty` and `getMaxEnumValFor##Ty` must be visible here (some of
// them are brought in by using-declarations at the top of this file).
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <>                                                                  \
  struct EnumTraits<Ty> {                                                      \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
REGISTER_ENUM_TYPE(TailCallKind);
REGISTER_ENUM_TYPE(Visibility);
} // namespace
116
117/// Parse an enum from the keyword, or default to the provided default value.
118/// The return type is the enum type by default, unless overridden with the
119/// second template argument.
120template <typename EnumTy, typename RetTy = EnumTy>
121static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
122 EnumTy defaultValue) {
123 SmallVector<StringRef, 10> names;
124 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
125 names.push_back(Elt: EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
126
127 int index = parseOptionalKeywordAlternative(parser, keywords: names);
128 if (index == -1)
129 return static_cast<RetTy>(defaultValue);
130 return static_cast<RetTy>(index);
131}
132
133//===----------------------------------------------------------------------===//
134// Operand bundle helpers.
135//===----------------------------------------------------------------------===//
136
137static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
138 TypeRange operandTypes, StringRef tag) {
139 p.printString(string: tag);
140 p << "(";
141
142 if (!operands.empty()) {
143 p.printOperands(container: operands);
144 p << " : ";
145 llvm::interleaveComma(c: operandTypes, os&: p);
146 }
147
148 p << ")";
149}
150
151static void printOpBundles(OpAsmPrinter &p, Operation *op,
152 OperandRangeRange opBundleOperands,
153 TypeRangeRange opBundleOperandTypes,
154 std::optional<ArrayAttr> opBundleTags) {
155 if (opBundleOperands.empty())
156 return;
157 assert(opBundleTags && "expect operand bundle tags");
158
159 p << "[";
160 llvm::interleaveComma(
161 c: llvm::zip(t&: opBundleOperands, u&: opBundleOperandTypes, args&: *opBundleTags), os&: p,
162 each_fn: [&p](auto bundle) {
163 auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
164 printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
165 bundleTag);
166 });
167 p << "]";
168}
169
170static ParseResult parseOneOpBundle(
171 OpAsmParser &p,
172 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
173 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
174 SmallVector<Attribute> &opBundleTags) {
175 SMLoc currentParserLoc = p.getCurrentLocation();
176 SmallVector<OpAsmParser::UnresolvedOperand> operands;
177 SmallVector<Type> types;
178 std::string tag;
179
180 if (p.parseString(string: &tag))
181 return p.emitError(loc: currentParserLoc, message: "expect operand bundle tag");
182
183 if (p.parseLParen())
184 return failure();
185
186 if (p.parseOptionalRParen()) {
187 if (p.parseOperandList(result&: operands) || p.parseColon() ||
188 p.parseTypeList(result&: types) || p.parseRParen())
189 return failure();
190 }
191
192 opBundleOperands.push_back(Elt: std::move(operands));
193 opBundleOperandTypes.push_back(Elt: std::move(types));
194 opBundleTags.push_back(Elt: StringAttr::get(context: p.getContext(), bytes: tag));
195
196 return success();
197}
198
199static std::optional<ParseResult> parseOpBundles(
200 OpAsmParser &p,
201 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
202 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
203 ArrayAttr &opBundleTags) {
204 if (p.parseOptionalLSquare())
205 return std::nullopt;
206
207 if (succeeded(Result: p.parseOptionalRSquare()))
208 return success();
209
210 SmallVector<Attribute> opBundleTagAttrs;
211 auto bundleParser = [&] {
212 return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
213 opBundleTags&: opBundleTagAttrs);
214 };
215 if (p.parseCommaSeparatedList(parseElementFn: bundleParser))
216 return failure();
217
218 if (p.parseRSquare())
219 return failure();
220
221 opBundleTags = ArrayAttr::get(context: p.getContext(), value: opBundleTagAttrs);
222
223 return success();
224}
225
226//===----------------------------------------------------------------------===//
227// Printing, parsing, folding and builder for LLVM::CmpOp.
228//===----------------------------------------------------------------------===//
229
230void ICmpOp::print(OpAsmPrinter &p) {
231 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(i: 0)
232 << ", " << getOperand(i: 1);
233 p.printOptionalAttrDict(attrs: (*this)->getAttrs(), elidedAttrs: {"predicate"});
234 p << " : " << getLhs().getType();
235}
236
237void FCmpOp::print(OpAsmPrinter &p) {
238 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(i: 0)
239 << ", " << getOperand(i: 1);
240 p.printOptionalAttrDict(attrs: processFMFAttr(attrs: (*this)->getAttrs()), elidedAttrs: {"predicate"});
241 p << " : " << getLhs().getType();
242}
243
244// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
245// attribute-dict? `:` type
246// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
247// attribute-dict? `:` type
248template <typename CmpPredicateType>
249static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
250 StringAttr predicateAttr;
251 OpAsmParser::UnresolvedOperand lhs, rhs;
252 Type type;
253 SMLoc predicateLoc, trailingTypeLoc;
254 if (parser.getCurrentLocation(loc: &predicateLoc) ||
255 parser.parseAttribute(result&: predicateAttr, attrName: "predicate", attrs&: result.attributes) ||
256 parser.parseOperand(result&: lhs) || parser.parseComma() ||
257 parser.parseOperand(result&: rhs) ||
258 parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() ||
259 parser.getCurrentLocation(loc: &trailingTypeLoc) || parser.parseType(result&: type) ||
260 parser.resolveOperand(operand: lhs, type, result&: result.operands) ||
261 parser.resolveOperand(operand: rhs, type, result&: result.operands))
262 return failure();
263
264 // Replace the string attribute `predicate` with an integer attribute.
265 int64_t predicateValue = 0;
266 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
267 std::optional<ICmpPredicate> predicate =
268 symbolizeICmpPredicate(predicateAttr.getValue());
269 if (!predicate)
270 return parser.emitError(loc: predicateLoc)
271 << "'" << predicateAttr.getValue()
272 << "' is an incorrect value of the 'predicate' attribute";
273 predicateValue = static_cast<int64_t>(*predicate);
274 } else {
275 std::optional<FCmpPredicate> predicate =
276 symbolizeFCmpPredicate(predicateAttr.getValue());
277 if (!predicate)
278 return parser.emitError(loc: predicateLoc)
279 << "'" << predicateAttr.getValue()
280 << "' is an incorrect value of the 'predicate' attribute";
281 predicateValue = static_cast<int64_t>(*predicate);
282 }
283
284 result.attributes.set(name: "predicate",
285 value: parser.getBuilder().getI64IntegerAttr(value: predicateValue));
286
287 // The result type is either i1 or a vector type <? x i1> if the inputs are
288 // vectors.
289 if (!isCompatibleType(type))
290 return parser.emitError(loc: trailingTypeLoc,
291 message: "expected LLVM dialect-compatible type");
292 result.addTypes(newTypes: getI1SameShape(type));
293 return success();
294}
295
296ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
297 return parseCmpOp<ICmpPredicate>(parser, result);
298}
299
300ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
301 return parseCmpOp<FCmpPredicate>(parser, result);
302}
303
304/// Returns a scalar or vector boolean attribute of the given type.
305static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
306 auto boolAttr = BoolAttr::get(context: ctx, value);
307 ShapedType shapedType = dyn_cast<ShapedType>(Val&: type);
308 if (!shapedType)
309 return boolAttr;
310 return DenseElementsAttr::get(type: shapedType, values: boolAttr);
311}
312
313OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
314 if (getPredicate() != ICmpPredicate::eq &&
315 getPredicate() != ICmpPredicate::ne)
316 return {};
317
318 // cmpi(eq/ne, x, x) -> true/false
319 if (getLhs() == getRhs())
320 return getBoolAttribute(type: getType(), ctx: getContext(),
321 value: getPredicate() == ICmpPredicate::eq);
322
323 // cmpi(eq/ne, alloca, null) -> false/true
324 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
325 return getBoolAttribute(type: getType(), ctx: getContext(),
326 value: getPredicate() == ICmpPredicate::ne);
327
328 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
329 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
330 Value lhs = getLhs();
331 Value rhs = getRhs();
332 getLhsMutable().assign(value: rhs);
333 getRhsMutable().assign(value: lhs);
334 return getResult();
335 }
336
337 return {};
338}
339
340//===----------------------------------------------------------------------===//
341// Printing, parsing and verification for LLVM::AllocaOp.
342//===----------------------------------------------------------------------===//
343
344void AllocaOp::print(OpAsmPrinter &p) {
345 auto funcTy =
346 FunctionType::get(context: getContext(), inputs: {getArraySize().getType()}, results: {getType()});
347
348 if (getInalloca())
349 p << " inalloca";
350
351 p << ' ' << getArraySize() << " x " << getElemType();
352 if (getAlignment() && *getAlignment() != 0)
353 p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
354 elidedAttrs: {kElemTypeAttrName, getInallocaAttrName()});
355 else
356 p.printOptionalAttrDict(
357 attrs: (*this)->getAttrs(),
358 elidedAttrs: {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
359 p << " : " << funcTy;
360}
361
362// <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
363// attribute-dict? `:` type `,` type
364ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
365 OpAsmParser::UnresolvedOperand arraySize;
366 Type type, elemType;
367 SMLoc trailingTypeLoc;
368
369 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "inalloca")))
370 result.addAttribute(name: getInallocaAttrName(name: result.name),
371 attr: UnitAttr::get(context: parser.getContext()));
372
373 if (parser.parseOperand(result&: arraySize) || parser.parseKeyword(keyword: "x") ||
374 parser.parseType(result&: elemType) ||
375 parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() ||
376 parser.getCurrentLocation(loc: &trailingTypeLoc) || parser.parseType(result&: type))
377 return failure();
378
379 std::optional<NamedAttribute> alignmentAttr =
380 result.attributes.getNamed(name: "alignment");
381 if (alignmentAttr.has_value()) {
382 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(Val: alignmentAttr->getValue());
383 if (!alignmentInt)
384 return parser.emitError(loc: parser.getNameLoc(),
385 message: "expected integer alignment");
386 if (alignmentInt.getValue().isZero())
387 result.attributes.erase(name: "alignment");
388 }
389
390 // Extract the result type from the trailing function type.
391 auto funcType = llvm::dyn_cast<FunctionType>(Val&: type);
392 if (!funcType || funcType.getNumInputs() != 1 ||
393 funcType.getNumResults() != 1)
394 return parser.emitError(
395 loc: trailingTypeLoc,
396 message: "expected trailing function type with one argument and one result");
397
398 if (parser.resolveOperand(operand: arraySize, type: funcType.getInput(i: 0), result&: result.operands))
399 return failure();
400
401 Type resultType = funcType.getResult(i: 0);
402 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(Val&: resultType))
403 result.addAttribute(name: kElemTypeAttrName, attr: TypeAttr::get(type: elemType));
404
405 result.addTypes(newTypes: {funcType.getResult(i: 0)});
406 return success();
407}
408
409LogicalResult AllocaOp::verify() {
410 // Only certain target extension types can be used in 'alloca'.
411 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(Val: getElemType());
412 targetExtType && !targetExtType.supportsMemOps())
413 return emitOpError()
414 << "this target extension type cannot be used in alloca";
415
416 return success();
417}
418
419//===----------------------------------------------------------------------===//
420// LLVM::BrOp
421//===----------------------------------------------------------------------===//
422
423SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
424 assert(index == 0 && "invalid successor index");
425 return SuccessorOperands(getDestOperandsMutable());
426}
427
428//===----------------------------------------------------------------------===//
429// LLVM::CondBrOp
430//===----------------------------------------------------------------------===//
431
432SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
433 assert(index < getNumSuccessors() && "invalid successor index");
434 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
435 : getFalseDestOperandsMutable());
436}
437
438void CondBrOp::build(OpBuilder &builder, OperationState &result,
439 Value condition, Block *trueDest, ValueRange trueOperands,
440 Block *falseDest, ValueRange falseOperands,
441 std::optional<std::pair<uint32_t, uint32_t>> weights) {
442 DenseI32ArrayAttr weightsAttr;
443 if (weights)
444 weightsAttr =
445 builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(weights->first),
446 static_cast<int32_t>(weights->second)});
447
448 build(odsBuilder&: builder, odsState&: result, condition, trueDestOperands: trueOperands, falseDestOperands: falseOperands, branch_weights: weightsAttr,
449 /*loop_annotation=*/{}, trueDest, falseDest);
450}
451
452//===----------------------------------------------------------------------===//
453// LLVM::SwitchOp
454//===----------------------------------------------------------------------===//
455
456void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
457 Block *defaultDestination, ValueRange defaultOperands,
458 DenseIntElementsAttr caseValues,
459 BlockRange caseDestinations,
460 ArrayRef<ValueRange> caseOperands,
461 ArrayRef<int32_t> branchWeights) {
462 DenseI32ArrayAttr weightsAttr;
463 if (!branchWeights.empty())
464 weightsAttr = builder.getDenseI32ArrayAttr(values: branchWeights);
465
466 build(odsBuilder&: builder, odsState&: result, value, defaultOperands, caseOperands, case_values: caseValues,
467 branch_weights: weightsAttr, defaultDestination, caseDestinations);
468}
469
470void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
471 Block *defaultDestination, ValueRange defaultOperands,
472 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
473 ArrayRef<ValueRange> caseOperands,
474 ArrayRef<int32_t> branchWeights) {
475 DenseIntElementsAttr caseValuesAttr;
476 if (!caseValues.empty()) {
477 ShapedType caseValueType = VectorType::get(
478 shape: static_cast<int64_t>(caseValues.size()), elementType: value.getType());
479 caseValuesAttr = DenseIntElementsAttr::get(type: caseValueType, arg&: caseValues);
480 }
481
482 build(builder, result, value, defaultDestination, defaultOperands,
483 caseValues: caseValuesAttr, caseDestinations, caseOperands, branchWeights);
484}
485
486void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
487 Block *defaultDestination, ValueRange defaultOperands,
488 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
489 ArrayRef<ValueRange> caseOperands,
490 ArrayRef<int32_t> branchWeights) {
491 DenseIntElementsAttr caseValuesAttr;
492 if (!caseValues.empty()) {
493 ShapedType caseValueType = VectorType::get(
494 shape: static_cast<int64_t>(caseValues.size()), elementType: value.getType());
495 caseValuesAttr = DenseIntElementsAttr::get(type: caseValueType, arg&: caseValues);
496 }
497
498 build(builder, result, value, defaultDestination, defaultOperands,
499 caseValues: caseValuesAttr, caseDestinations, caseOperands, branchWeights);
500}
501
502/// <cases> ::= `[` (case (`,` case )* )? `]`
503/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
504static ParseResult parseSwitchOpCases(
505 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
506 SmallVectorImpl<Block *> &caseDestinations,
507 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
508 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
509 if (failed(Result: parser.parseLSquare()))
510 return failure();
511 if (succeeded(Result: parser.parseOptionalRSquare()))
512 return success();
513 SmallVector<APInt> values;
514 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
515 auto parseCase = [&]() {
516 int64_t value = 0;
517 if (failed(Result: parser.parseInteger(result&: value)))
518 return failure();
519 values.push_back(Elt: APInt(bitWidth, value, /*isSigned=*/true));
520
521 Block *destination;
522 SmallVector<OpAsmParser::UnresolvedOperand> operands;
523 SmallVector<Type> operandTypes;
524 if (parser.parseColon() || parser.parseSuccessor(dest&: destination))
525 return failure();
526 if (!parser.parseOptionalLParen()) {
527 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::None,
528 /*allowResultNumber=*/false) ||
529 parser.parseColonTypeList(result&: operandTypes) || parser.parseRParen())
530 return failure();
531 }
532 caseDestinations.push_back(Elt: destination);
533 caseOperands.emplace_back(Args&: operands);
534 caseOperandTypes.emplace_back(Args&: operandTypes);
535 return success();
536 };
537 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseCase)))
538 return failure();
539
540 ShapedType caseValueType =
541 VectorType::get(shape: static_cast<int64_t>(values.size()), elementType: flagType);
542 caseValues = DenseIntElementsAttr::get(type: caseValueType, arg&: values);
543 return parser.parseRSquare();
544}
545
546static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
547 DenseIntElementsAttr caseValues,
548 SuccessorRange caseDestinations,
549 OperandRangeRange caseOperands,
550 const TypeRangeRange &caseOperandTypes) {
551 p << '[';
552 p.printNewline();
553 if (!caseValues) {
554 p << ']';
555 return;
556 }
557
558 size_t index = 0;
559 llvm::interleave(
560 c: llvm::zip(t&: caseValues, u&: caseDestinations),
561 each_fn: [&](auto i) {
562 p << " ";
563 p << std::get<0>(i);
564 p << ": ";
565 p.printSuccessorAndUseList(successor: std::get<1>(i), succOperands: caseOperands[index++]);
566 },
567 between_fn: [&] {
568 p << ',';
569 p.printNewline();
570 });
571 p.printNewline();
572 p << ']';
573}
574
575LogicalResult SwitchOp::verify() {
576 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
577 (getCaseValues() &&
578 getCaseValues()->size() !=
579 static_cast<int64_t>(getCaseDestinations().size())))
580 return emitOpError(message: "expects number of case values to match number of "
581 "case destinations");
582 if (getCaseValues() &&
583 getValue().getType() != getCaseValues()->getElementType())
584 return emitError(message: "expects case value type to match condition value type");
585 return success();
586}
587
588SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
589 assert(index < getNumSuccessors() && "invalid successor index");
590 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
591 : getCaseOperandsMutable(index: index - 1));
592}
593
594//===----------------------------------------------------------------------===//
595// Code for LLVM::GEPOp.
596//===----------------------------------------------------------------------===//
597
// Out-of-line definition of the static constexpr data member. From C++17 on,
// static constexpr members are implicitly inline, which makes this redundant
// (and deprecated). It only matters for ODR-uses in pre-C++17 builds.
constexpr int32_t GEPOp::kDynamicIndex;
599
600GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
601 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
602 getDynamicIndices());
603}
604
605/// Returns the elemental type of any LLVM-compatible vector type or self.
606static Type extractVectorElementType(Type type) {
607 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: type))
608 return vectorType.getElementType();
609 return type;
610}
611
/// Splits `indices` into the encoding GEPOp stores:
///  - every index appends one entry to `rawConstantIndices`. The entry is
///    either the constant itself or `GEPOp::kDynamicIndex`;
///  - for each `kDynamicIndex` entry, the corresponding Value is appended to
///    `dynamicIndices`.
/// `currType` must be the GEP's element type. It is tracked while walking the
/// indices, so that a Value index selecting a struct member is encoded as a
/// constant when it matches a constant integer that fits in
/// `kGEPConstantBitWidth` signed bits.
static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
                               SmallVectorImpl<int32_t> &rawConstantIndices,
                               SmallVectorImpl<Value> &dynamicIndices) {
  for (const GEPArg &iter : indices) {
    // If the thing we are currently indexing into is a struct we must turn
    // any integer constants into constant indices. If this is not possible
    // we don't do anything here. The verifier will catch it and emit a proper
    // error. All other canonicalization is done in the fold method.
    // The first index (rawConstantIndices still empty) never requires a
    // constant, since it only offsets the base pointer.
    bool requiresConst = !rawConstantIndices.empty() &&
                         isa_and_nonnull<LLVMStructType>(currType);
    if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
      APInt intC;
      if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
          intC.isSignedIntN(kGEPConstantBitWidth)) {
        rawConstantIndices.push_back(intC.getSExtValue());
      } else {
        rawConstantIndices.push_back(GEPOp::kDynamicIndex);
        dynamicIndices.push_back(val);
      }
    } else {
      rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
    }

    // Skip for very first iteration of this loop. First index does not index
    // within the aggregates, but is just a pointer offset. Also stop refining
    // once the indexed type has become unknown (null).
    if (rawConstantIndices.size() == 1 || !currType)
      continue;

    // Step into the type selected by the index just recorded. Vectors and
    // arrays yield their element type. A struct yields the selected member,
    // or null if the recorded entry is not a valid member position.
    // NOTE(review): this presumably also covers a dynamic member index, if
    // kDynamicIndex is negative — TODO confirm. Any other type yields null,
    // which disables constant conversion for the remaining indices.
    currType = TypeSwitch<Type, Type>(currType)
                   .Case<VectorType, LLVMArrayType>([](auto containerType) {
                     return containerType.getElementType();
                   })
                   .Case([&](LLVMStructType structType) -> Type {
                     int64_t memberIndex = rawConstantIndices.back();
                     if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
                                                 structType.getBody().size())
                       return structType.getBody()[memberIndex];
                     return nullptr;
                   })
                   .Default(Type(nullptr));
  }
}
659
660void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
661 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
662 GEPNoWrapFlags noWrapFlags,
663 ArrayRef<NamedAttribute> attributes) {
664 SmallVector<int32_t> rawConstantIndices;
665 SmallVector<Value> dynamicIndices;
666 destructureIndices(currType: elementType, indices, rawConstantIndices, dynamicIndices);
667
668 result.addTypes(newTypes: resultType);
669 result.addAttributes(newAttributes: attributes);
670 result.getOrAddProperties<Properties>().rawConstantIndices =
671 builder.getDenseI32ArrayAttr(values: rawConstantIndices);
672 result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
673 result.getOrAddProperties<Properties>().elem_type =
674 TypeAttr::get(type: elementType);
675 result.addOperands(newOperands: basePtr);
676 result.addOperands(newOperands: dynamicIndices);
677}
678
679void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
680 Type elementType, Value basePtr, ValueRange indices,
681 GEPNoWrapFlags noWrapFlags,
682 ArrayRef<NamedAttribute> attributes) {
683 build(builder, result, resultType, elementType, basePtr,
684 indices: SmallVector<GEPArg>(indices), noWrapFlags, attributes);
685}
686
687static ParseResult
688parseGEPIndices(OpAsmParser &parser,
689 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
690 DenseI32ArrayAttr &rawConstantIndices) {
691 SmallVector<int32_t> constantIndices;
692
693 auto idxParser = [&]() -> ParseResult {
694 int32_t constantIndex;
695 OptionalParseResult parsedInteger =
696 parser.parseOptionalInteger(result&: constantIndex);
697 if (parsedInteger.has_value()) {
698 if (failed(Result: parsedInteger.value()))
699 return failure();
700 constantIndices.push_back(Elt: constantIndex);
701 return success();
702 }
703
704 constantIndices.push_back(Elt: LLVM::GEPOp::kDynamicIndex);
705 return parser.parseOperand(result&: indices.emplace_back());
706 };
707 if (parser.parseCommaSeparatedList(parseElementFn: idxParser))
708 return failure();
709
710 rawConstantIndices =
711 DenseI32ArrayAttr::get(context: parser.getContext(), content: constantIndices);
712 return success();
713}
714
715static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
716 OperandRange indices,
717 DenseI32ArrayAttr rawConstantIndices) {
718 llvm::interleaveComma(
719 c: GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), os&: printer,
720 each_fn: [&](PointerUnion<IntegerAttr, Value> cst) {
721 if (Value val = llvm::dyn_cast_if_present<Value>(Val&: cst))
722 printer.printOperand(value: val);
723 else
724 printer << cast<IntegerAttr>(Val&: cst).getInt();
725 });
726}
727
728/// For the given `indices`, check if they comply with `baseGEPType`,
729/// especially check against LLVMStructTypes nested within.
730static LogicalResult
731verifyStructIndices(Type baseGEPType, unsigned indexPos,
732 GEPIndicesAdaptor<ValueRange> indices,
733 function_ref<InFlightDiagnostic()> emitOpError) {
734 if (indexPos >= indices.size())
735 // Stop searching
736 return success();
737
738 return TypeSwitch<Type, LogicalResult>(baseGEPType)
739 .Case<LLVMStructType>(caseFn: [&](LLVMStructType structType) -> LogicalResult {
740 auto attr = dyn_cast<IntegerAttr>(Val: indices[indexPos]);
741 if (!attr)
742 return emitOpError() << "expected index " << indexPos
743 << " indexing a struct to be constant";
744
745 int32_t gepIndex = attr.getInt();
746 ArrayRef<Type> elementTypes = structType.getBody();
747 if (gepIndex < 0 ||
748 static_cast<size_t>(gepIndex) >= elementTypes.size())
749 return emitOpError() << "index " << indexPos
750 << " indexing a struct is out of bounds";
751
752 // Instead of recursively going into every children types, we only
753 // dive into the one indexed by gepIndex.
754 return verifyStructIndices(baseGEPType: elementTypes[gepIndex], indexPos: indexPos + 1,
755 indices, emitOpError);
756 })
757 .Case<VectorType, LLVMArrayType>(
758 caseFn: [&](auto containerType) -> LogicalResult {
759 return verifyStructIndices(containerType.getElementType(),
760 indexPos + 1, indices, emitOpError);
761 })
762 .Default(defaultFn: [&](auto otherType) -> LogicalResult {
763 return emitOpError()
764 << "type " << otherType << " cannot be indexed (index #"
765 << indexPos << ")";
766 });
767}
768
769/// Driver function around `verifyStructIndices`.
770static LogicalResult
771verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
772 function_ref<InFlightDiagnostic()> emitOpError) {
773 return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
774}
775
776LogicalResult LLVM::GEPOp::verify() {
777 if (static_cast<size_t>(
778 llvm::count(Range: getRawConstantIndices(), Element: kDynamicIndex)) !=
779 getDynamicIndices().size())
780 return emitOpError(message: "expected as many dynamic indices as specified in '")
781 << getRawConstantIndicesAttrName().getValue() << "'";
782
783 if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
784 return emitOpError(message: "'inbounds_flag' cannot be used directly.");
785
786 return verifyStructIndices(baseGEPType: getElemType(), indices: getIndices(),
787 emitOpError: [&] { return emitOpError(); });
788}
789
790//===----------------------------------------------------------------------===//
791// LoadOp
792//===----------------------------------------------------------------------===//
793
794void LoadOp::getEffects(
795 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
796 &effects) {
797 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getAddrMutable());
798 // Volatile operations can have target-specific read-write effects on
799 // memory besides the one referred to by the pointer operand.
800 // Similarly, atomic operations that are monotonic or stricter cause
801 // synchronization that from a language point-of-view, are arbitrary
802 // read-writes into memory.
803 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
804 getOrdering() != AtomicOrdering::unordered)) {
805 effects.emplace_back(Args: MemoryEffects::Write::get());
806 effects.emplace_back(Args: MemoryEffects::Read::get());
807 }
808}
809
810/// Returns true if the given type is supported by atomic operations. All
811/// integer, float, and pointer types with a power-of-two bitsize and a minimal
812/// size of 8 bits are supported.
813static bool isTypeCompatibleWithAtomicOp(Type type,
814 const DataLayout &dataLayout) {
815 if (!isa<IntegerType, LLVMPointerType>(Val: type))
816 if (!isCompatibleFloatingPointType(type))
817 return false;
818
819 llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(t: type);
820 if (bitWidth.isScalable())
821 return false;
822 // Needs to be at least 8 bits and a power of two.
823 return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
824}
825
826/// Verifies the attributes and the type of atomic memory access operations.
827template <typename OpTy>
828LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
829 ArrayRef<AtomicOrdering> unsupportedOrderings) {
830 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
831 DataLayout dataLayout = DataLayout::closest(op: memOp);
832 if (!isTypeCompatibleWithAtomicOp(type: valueType, dataLayout))
833 return memOp.emitOpError("unsupported type ")
834 << valueType << " for atomic access";
835 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
836 return memOp.emitOpError("unsupported ordering '")
837 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
838 if (!memOp.getAlignment())
839 return memOp.emitOpError("expected alignment for atomic access");
840 return success();
841 }
842 if (memOp.getSyncscope())
843 return memOp.emitOpError(
844 "expected syncscope to be null for non-atomic access");
845 return success();
846}
847
848LogicalResult LoadOp::verify() {
849 Type valueType = getResult().getType();
850 return verifyAtomicMemOp(memOp: *this, valueType,
851 unsupportedOrderings: {AtomicOrdering::release, AtomicOrdering::acq_rel});
852}
853
854void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
855 Value addr, unsigned alignment, bool isVolatile,
856 bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
857 AtomicOrdering ordering, StringRef syncscope) {
858 build(odsBuilder&: builder, odsState&: state, res: type, addr,
859 alignment: alignment ? builder.getI64IntegerAttr(value: alignment) : nullptr, volatile_: isVolatile,
860 nontemporal: isNonTemporal, invariant: isInvariant, invariantGroup: isInvariantGroup, ordering,
861 syncscope: syncscope.empty() ? nullptr : builder.getStringAttr(bytes: syncscope),
862 /*dereferenceable=*/nullptr,
863 /*access_groups=*/nullptr,
864 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
865 /*tbaa=*/nullptr);
866}
867
868//===----------------------------------------------------------------------===//
869// StoreOp
870//===----------------------------------------------------------------------===//
871
872void StoreOp::getEffects(
873 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
874 &effects) {
875 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: &getAddrMutable());
876 // Volatile operations can have target-specific read-write effects on
877 // memory besides the one referred to by the pointer operand.
878 // Similarly, atomic operations that are monotonic or stricter cause
879 // synchronization that from a language point-of-view, are arbitrary
880 // read-writes into memory.
881 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
882 getOrdering() != AtomicOrdering::unordered)) {
883 effects.emplace_back(Args: MemoryEffects::Write::get());
884 effects.emplace_back(Args: MemoryEffects::Read::get());
885 }
886}
887
888LogicalResult StoreOp::verify() {
889 Type valueType = getValue().getType();
890 return verifyAtomicMemOp(memOp: *this, valueType,
891 unsupportedOrderings: {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
892}
893
894void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
895 Value addr, unsigned alignment, bool isVolatile,
896 bool isNonTemporal, bool isInvariantGroup,
897 AtomicOrdering ordering, StringRef syncscope) {
898 build(odsBuilder&: builder, odsState&: state, value, addr,
899 alignment: alignment ? builder.getI64IntegerAttr(value: alignment) : nullptr, volatile_: isVolatile,
900 nontemporal: isNonTemporal, invariantGroup: isInvariantGroup, ordering,
901 syncscope: syncscope.empty() ? nullptr : builder.getStringAttr(bytes: syncscope),
902 /*access_groups=*/nullptr,
903 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
904}
905
906//===----------------------------------------------------------------------===//
907// CallOp
908//===----------------------------------------------------------------------===//
909
910/// Gets the MLIR Op-like result types of a LLVMFunctionType.
911static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
912 SmallVector<Type, 1> results;
913 Type resultType = calleeType.getReturnType();
914 if (!isa<LLVM::LLVMVoidType>(Val: resultType))
915 results.push_back(Elt: resultType);
916 return results;
917}
918
919/// Gets the variadic callee type for a LLVMFunctionType.
920static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
921 return calleeType.isVarArg() ? TypeAttr::get(type: calleeType) : nullptr;
922}
923
924/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
925static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
926 ValueRange args) {
927 Type resultType;
928 if (results.empty())
929 resultType = LLVMVoidType::get(ctx: context);
930 else
931 resultType = results.front();
932 return LLVMFunctionType::get(result: resultType, arguments: llvm::to_vector(Range: args.getTypes()),
933 /*isVarArg=*/false);
934}
935
936void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
937 StringRef callee, ValueRange args) {
938 build(odsBuilder&: builder, odsState&: state, results, callee: builder.getStringAttr(bytes: callee), args);
939}
940
941void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
942 StringAttr callee, ValueRange args) {
943 build(odsBuilder&: builder, odsState&: state, results, callee: SymbolRefAttr::get(value: callee), args);
944}
945
946void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
947 FlatSymbolRefAttr callee, ValueRange args) {
948 assert(callee && "expected non-null callee in direct call builder");
949 build(odsBuilder&: builder, odsState&: state, resultTypes: results,
950 /*var_callee_type=*/nullptr, callee, callee_operands: args, /*fastmathFlags=*/nullptr,
951 /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
952 /*memory_effects=*/nullptr,
953 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
954 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
955 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
956 /*access_groups=*/no_inline: nullptr, /*alias_scopes=*/always_inline: nullptr,
957 /*noalias_scopes=*/inline_hint: nullptr, /*tbaa=*/access_groups: nullptr,
958 /*no_inline=*/alias_scopes: nullptr, /*always_inline=*/noalias_scopes: nullptr,
959 /*inline_hint=*/tbaa: nullptr);
960}
961
962void CallOp::build(OpBuilder &builder, OperationState &state,
963 LLVMFunctionType calleeType, StringRef callee,
964 ValueRange args) {
965 build(odsBuilder&: builder, odsState&: state, calleeType, callee: builder.getStringAttr(bytes: callee), args);
966}
967
968void CallOp::build(OpBuilder &builder, OperationState &state,
969 LLVMFunctionType calleeType, StringAttr callee,
970 ValueRange args) {
971 build(odsBuilder&: builder, odsState&: state, calleeType, callee: SymbolRefAttr::get(value: callee), args);
972}
973
974void CallOp::build(OpBuilder &builder, OperationState &state,
975 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
976 ValueRange args) {
977 build(odsBuilder&: builder, odsState&: state, resultTypes: getCallOpResultTypes(calleeType),
978 var_callee_type: getCallOpVarCalleeType(calleeType), callee, callee_operands: args,
979 /*fastmathFlags=*/nullptr,
980 /*CConv=*/nullptr,
981 /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
982 /*convergent=*/nullptr,
983 /*no_unwind=*/nullptr, /*will_return=*/nullptr,
984 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
985 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
986 /*access_groups=*/no_inline: nullptr,
987 /*alias_scopes=*/always_inline: nullptr, /*noalias_scopes=*/inline_hint: nullptr, /*tbaa=*/access_groups: nullptr,
988 /*no_inline=*/alias_scopes: nullptr, /*always_inline=*/noalias_scopes: nullptr,
989 /*inline_hint=*/tbaa: nullptr);
990}
991
992void CallOp::build(OpBuilder &builder, OperationState &state,
993 LLVMFunctionType calleeType, ValueRange args) {
994 build(odsBuilder&: builder, odsState&: state, resultTypes: getCallOpResultTypes(calleeType),
995 var_callee_type: getCallOpVarCalleeType(calleeType),
996 /*callee=*/nullptr, callee_operands: args,
997 /*fastmathFlags=*/nullptr,
998 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
999 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1000 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1001 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1002 /*access_groups=*/no_inline: nullptr, /*alias_scopes=*/always_inline: nullptr,
1003 /*noalias_scopes=*/inline_hint: nullptr, /*tbaa=*/access_groups: nullptr,
1004 /*no_inline=*/alias_scopes: nullptr, /*always_inline=*/noalias_scopes: nullptr,
1005 /*inline_hint=*/tbaa: nullptr);
1006}
1007
1008void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1009 ValueRange args) {
1010 auto calleeType = func.getFunctionType();
1011 build(odsBuilder&: builder, odsState&: state, resultTypes: getCallOpResultTypes(calleeType),
1012 var_callee_type: getCallOpVarCalleeType(calleeType), callee: SymbolRefAttr::get(symbol: func), callee_operands: args,
1013 /*fastmathFlags=*/nullptr,
1014 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1015 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1016 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1017 /*access_groups=*/arg_attrs: nullptr, /*alias_scopes=*/res_attrs: nullptr,
1018 /*arg_attrs=*/no_inline: nullptr, /*res_attrs=*/always_inline: nullptr,
1019 /*noalias_scopes=*/inline_hint: nullptr, /*tbaa=*/access_groups: nullptr,
1020 /*no_inline=*/alias_scopes: nullptr, /*always_inline=*/noalias_scopes: nullptr,
1021 /*inline_hint=*/tbaa: nullptr);
1022}
1023
1024CallInterfaceCallable CallOp::getCallableForCallee() {
1025 // Direct call.
1026 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1027 return calleeAttr;
1028 // Indirect call, callee Value is the first operand.
1029 return getOperand(i: 0);
1030}
1031
1032void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1033 // Direct call.
1034 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1035 auto symRef = cast<SymbolRefAttr>(Val&: callee);
1036 return setCalleeAttr(cast<FlatSymbolRefAttr>(Val&: symRef));
1037 }
1038 // Indirect call, callee Value is the first operand.
1039 return setOperand(i: 0, value: cast<Value>(Val&: callee));
1040}
1041
1042Operation::operand_range CallOp::getArgOperands() {
1043 return getCalleeOperands().drop_front(n: getCallee().has_value() ? 0 : 1);
1044}
1045
1046MutableOperandRange CallOp::getArgOperandsMutable() {
1047 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1048 getCalleeOperands().size());
1049}
1050
1051/// Verify that an inlinable callsite of a debug-info-bearing function in a
1052/// debug-info-bearing function has a debug location attached to it. This
1053/// mirrors an LLVM IR verifier.
1054static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1055 if (callee.isExternal())
1056 return success();
1057 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1058 if (!parentFunc)
1059 return success();
1060
1061 auto hasSubprogram = [](Operation *op) {
1062 return op->getLoc()
1063 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1064 nullptr;
1065 };
1066 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1067 return success();
1068 bool containsLoc = !isa<UnknownLoc>(Val: callOp->getLoc());
1069 if (!containsLoc)
1070 return callOp.emitError()
1071 << "inlinable function call in a function with a DISubprogram "
1072 "location must have a debug location";
1073 return success();
1074}
1075
1076/// Verify that the parameter and return types of the variadic callee type match
1077/// the `callOp` argument and result types.
1078template <typename OpTy>
1079LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1080 std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1081 if (!varCalleeType)
1082 return success();
1083
1084 // Verify the variadic callee type is a variadic function type.
1085 if (!varCalleeType->isVarArg())
1086 return callOp.emitOpError(
1087 "expected var_callee_type to be a variadic function type");
1088
1089 // Verify the variadic callee type has at most as many parameters as the call
1090 // has argument operands.
1091 if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1092 return callOp.emitOpError("expected var_callee_type to have at most ")
1093 << callOp.getArgOperands().size() << " parameters";
1094
1095 // Verify the variadic callee type matches the call argument types.
1096 for (auto [paramType, operand] :
1097 llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1098 if (paramType != operand.getType())
1099 return callOp.emitOpError()
1100 << "var_callee_type parameter type mismatch: " << paramType
1101 << " != " << operand.getType();
1102
1103 // Verify the variadic callee type matches the call result type.
1104 if (!callOp.getNumResults()) {
1105 if (!isa<LLVMVoidType>(Val: varCalleeType->getReturnType()))
1106 return callOp.emitOpError("expected var_callee_type to return void");
1107 } else {
1108 if (callOp.getResult().getType() != varCalleeType->getReturnType())
1109 return callOp.emitOpError("var_callee_type return type mismatch: ")
1110 << varCalleeType->getReturnType()
1111 << " != " << callOp.getResult().getType();
1112 }
1113 return success();
1114}
1115
1116template <typename OpType>
1117static LogicalResult verifyOperandBundles(OpType &op) {
1118 OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1119 std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1120
1121 auto isStringAttr = [](Attribute tagAttr) {
1122 return isa<StringAttr>(Val: tagAttr);
1123 };
1124 if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1125 return op.emitError("operand bundle tag must be a StringAttr");
1126
1127 size_t numOpBundles = opBundleOperands.size();
1128 size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1129 if (numOpBundles != numOpBundleTags)
1130 return op.emitError("expected ")
1131 << numOpBundles << " operand bundle tags, but actually got "
1132 << numOpBundleTags;
1133
1134 return success();
1135}
1136
1137LogicalResult CallOp::verify() { return verifyOperandBundles(op&: *this); }
1138
1139LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1140 if (failed(Result: verifyCallOpVarCalleeType(callOp: *this)))
1141 return failure();
1142
1143 // Type for the callee, we'll get it differently depending if it is a direct
1144 // or indirect call.
1145 Type fnType;
1146
1147 bool isIndirect = false;
1148
1149 // If this is an indirect call, the callee attribute is missing.
1150 FlatSymbolRefAttr calleeName = getCalleeAttr();
1151 if (!calleeName) {
1152 isIndirect = true;
1153 if (!getNumOperands())
1154 return emitOpError(
1155 message: "must have either a `callee` attribute or at least an operand");
1156 auto ptrType = llvm::dyn_cast<LLVMPointerType>(Val: getOperand(i: 0).getType());
1157 if (!ptrType)
1158 return emitOpError(message: "indirect call expects a pointer as callee: ")
1159 << getOperand(i: 0).getType();
1160
1161 return success();
1162 } else {
1163 Operation *callee =
1164 symbolTable.lookupNearestSymbolFrom(from: *this, symbol: calleeName.getAttr());
1165 if (!callee)
1166 return emitOpError()
1167 << "'" << calleeName.getValue()
1168 << "' does not reference a symbol in the current scope";
1169 auto fn = dyn_cast<LLVMFuncOp>(Val: callee);
1170 if (!fn)
1171 return emitOpError() << "'" << calleeName.getValue()
1172 << "' does not reference a valid LLVM function";
1173
1174 if (failed(Result: verifyCallOpDebugInfo(callOp: *this, callee: fn)))
1175 return failure();
1176 fnType = fn.getFunctionType();
1177 }
1178
1179 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(Val&: fnType);
1180 if (!funcType)
1181 return emitOpError(message: "callee does not have a functional type: ") << fnType;
1182
1183 if (funcType.isVarArg() && !getVarCalleeType())
1184 return emitOpError() << "missing var_callee_type attribute for vararg call";
1185
1186 // Verify that the operand and result types match the callee.
1187
1188 if (!funcType.isVarArg() &&
1189 funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
1190 return emitOpError() << "incorrect number of operands ("
1191 << (getCalleeOperands().size() - isIndirect)
1192 << ") for callee (expecting: "
1193 << funcType.getNumParams() << ")";
1194
1195 if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
1196 return emitOpError() << "incorrect number of operands ("
1197 << (getCalleeOperands().size() - isIndirect)
1198 << ") for varargs callee (expecting at least: "
1199 << funcType.getNumParams() << ")";
1200
1201 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1202 if (getOperand(i: i + isIndirect).getType() != funcType.getParamType(i))
1203 return emitOpError() << "operand type mismatch for operand " << i << ": "
1204 << getOperand(i: i + isIndirect).getType()
1205 << " != " << funcType.getParamType(i);
1206
1207 if (getNumResults() == 0 &&
1208 !llvm::isa<LLVM::LLVMVoidType>(Val: funcType.getReturnType()))
1209 return emitOpError() << "expected function call to produce a value";
1210
1211 if (getNumResults() != 0 &&
1212 llvm::isa<LLVM::LLVMVoidType>(Val: funcType.getReturnType()))
1213 return emitOpError()
1214 << "calling function with void result must not produce values";
1215
1216 if (getNumResults() > 1)
1217 return emitOpError()
1218 << "expected LLVM function call to produce 0 or 1 result";
1219
1220 if (getNumResults() && getResult().getType() != funcType.getReturnType())
1221 return emitOpError() << "result type mismatch: " << getResult().getType()
1222 << " != " << funcType.getReturnType();
1223
1224 return success();
1225}
1226
1227void CallOp::print(OpAsmPrinter &p) {
1228 auto callee = getCallee();
1229 bool isDirect = callee.has_value();
1230
1231 p << ' ';
1232
1233 // Print calling convention.
1234 if (getCConv() != LLVM::CConv::C)
1235 p << stringifyCConv(getCConv()) << ' ';
1236
1237 if (getTailCallKind() != LLVM::TailCallKind::None)
1238 p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
1239
1240 // Print the direct callee if present as a function attribute, or an indirect
1241 // callee (first operand) otherwise.
1242 if (isDirect)
1243 p.printSymbolName(symbolRef: callee.value());
1244 else
1245 p << getOperand(i: 0);
1246
1247 auto args = getCalleeOperands().drop_front(n: isDirect ? 0 : 1);
1248 p << '(' << args << ')';
1249
1250 // Print the variadic callee type if the call is variadic.
1251 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1252 p << " vararg(" << *varCalleeType << ")";
1253
1254 if (!getOpBundleOperands().empty()) {
1255 p << " ";
1256 printOpBundles(p, op: *this, opBundleOperands: getOpBundleOperands(),
1257 opBundleOperandTypes: getOpBundleOperands().getTypes(), opBundleTags: getOpBundleTags());
1258 }
1259
1260 p.printOptionalAttrDict(attrs: processFMFAttr(attrs: (*this)->getAttrs()),
1261 elidedAttrs: {getCalleeAttrName(), getTailCallKindAttrName(),
1262 getVarCalleeTypeAttrName(), getCConvAttrName(),
1263 getOperandSegmentSizesAttrName(),
1264 getOpBundleSizesAttrName(),
1265 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1266 getResAttrsAttrName()});
1267
1268 p << " : ";
1269 if (!isDirect)
1270 p << getOperand(i: 0).getType() << ", ";
1271
1272 // Reconstruct the MLIR function type from operand and result types.
1273 call_interface_impl::printFunctionSignature(
1274 p, argTypes: args.getTypes(), argAttrs: getArgAttrsAttr(),
1275 /*isVariadic=*/false, resultTypes: getResultTypes(), resultAttrs: getResAttrsAttr());
1276}
1277
1278/// Parses the type of a call operation and resolves the operands if the parsing
1279/// succeeds. Returns failure otherwise.
1280static ParseResult parseCallTypeAndResolveOperands(
1281 OpAsmParser &parser, OperationState &result, bool isDirect,
1282 ArrayRef<OpAsmParser::UnresolvedOperand> operands,
1283 SmallVectorImpl<DictionaryAttr> &argAttrs,
1284 SmallVectorImpl<DictionaryAttr> &resultAttrs) {
1285 SMLoc trailingTypesLoc = parser.getCurrentLocation();
1286 SmallVector<Type> types;
1287 if (parser.parseColon())
1288 return failure();
1289 if (!isDirect) {
1290 types.emplace_back();
1291 if (parser.parseType(result&: types.back()))
1292 return failure();
1293 if (parser.parseOptionalComma())
1294 return parser.emitError(
1295 loc: trailingTypesLoc, message: "expected indirect call to have 2 trailing types");
1296 }
1297 SmallVector<Type> argTypes;
1298 SmallVector<Type> resTypes;
1299 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1300 resultTypes&: resTypes, resultAttrs)) {
1301 if (isDirect)
1302 return parser.emitError(loc: trailingTypesLoc,
1303 message: "expected direct call to have 1 trailing types");
1304 return parser.emitError(loc: trailingTypesLoc,
1305 message: "expected trailing function type");
1306 }
1307
1308 if (resTypes.size() > 1)
1309 return parser.emitError(loc: trailingTypesLoc,
1310 message: "expected function with 0 or 1 result");
1311 if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(Val: resTypes[0]))
1312 return parser.emitError(loc: trailingTypesLoc,
1313 message: "expected a non-void result type");
1314
1315 // The head element of the types list matches the callee type for
1316 // indirect calls, while the types list is emtpy for direct calls.
1317 // Append the function input types to resolve the call operation
1318 // operands.
1319 llvm::append_range(C&: types, R&: argTypes);
1320 if (parser.resolveOperands(operands, types, loc: parser.getNameLoc(),
1321 result&: result.operands))
1322 return failure();
1323 if (resTypes.size() != 0)
1324 result.addTypes(newTypes: resTypes);
1325
1326 return success();
1327}
1328
1329/// Parses an optional function pointer operand before the call argument list
1330/// for indirect calls, or stops parsing at the function identifier otherwise.
1331static ParseResult parseOptionalCallFuncPtr(
1332 OpAsmParser &parser,
1333 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) {
1334 OpAsmParser::UnresolvedOperand funcPtrOperand;
1335 OptionalParseResult parseResult = parser.parseOptionalOperand(result&: funcPtrOperand);
1336 if (parseResult.has_value()) {
1337 if (failed(Result: *parseResult))
1338 return *parseResult;
1339 operands.push_back(Elt: funcPtrOperand);
1340 }
1341 return success();
1342}
1343
1344static ParseResult resolveOpBundleOperands(
1345 OpAsmParser &parser, SMLoc loc, OperationState &state,
1346 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
1347 ArrayRef<SmallVector<Type>> opBundleOperandTypes,
1348 StringAttr opBundleSizesAttrName) {
1349 unsigned opBundleIndex = 0;
1350 for (const auto &[operands, types] :
1351 llvm::zip_equal(t&: opBundleOperands, u&: opBundleOperandTypes)) {
1352 if (operands.size() != types.size())
1353 return parser.emitError(loc, message: "expected ")
1354 << operands.size()
1355 << " types for operand bundle operands for operand bundle #"
1356 << opBundleIndex << ", but actually got " << types.size();
1357 if (parser.resolveOperands(operands, types, loc, result&: state.operands))
1358 return failure();
1359 }
1360
1361 SmallVector<int32_t> opBundleSizes;
1362 opBundleSizes.reserve(N: opBundleOperands.size());
1363 for (const auto &operands : opBundleOperands)
1364 opBundleSizes.push_back(Elt: operands.size());
1365
1366 state.addAttribute(
1367 name: opBundleSizesAttrName,
1368 attr: DenseI32ArrayAttr::get(context: parser.getContext(), content: opBundleSizes));
1369
1370 return success();
1371}
1372
1373// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1374// `(` ssa-use-list `)`
1375// ( `vararg(` var-callee-type `)` )?
1376// ( `[` op-bundles-list `]` )?
1377// attribute-dict? `:` (type `,`)? function-type
1378ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1379 SymbolRefAttr funcAttr;
1380 TypeAttr varCalleeType;
1381 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1382 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
1383 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1384 ArrayAttr opBundleTags;
1385
1386 // Default to C Calling Convention if no keyword is provided.
1387 result.addAttribute(
1388 name: getCConvAttrName(name: result.name),
1389 attr: CConvAttr::get(context: parser.getContext(),
1390 CallingConv: parseOptionalLLVMKeyword<CConv>(parser, defaultValue: LLVM::CConv::C)));
1391
1392 result.addAttribute(
1393 name: getTailCallKindAttrName(name: result.name),
1394 attr: TailCallKindAttr::get(context: parser.getContext(),
1395 tailCallKind: parseOptionalLLVMKeyword<TailCallKind>(
1396 parser, defaultValue: LLVM::TailCallKind::None)));
1397
1398 // Parse a function pointer for indirect calls.
1399 if (parseOptionalCallFuncPtr(parser, operands))
1400 return failure();
1401 bool isDirect = operands.empty();
1402
1403 // Parse a function identifier for direct calls.
1404 if (isDirect)
1405 if (parser.parseAttribute(result&: funcAttr, attrName: "callee", attrs&: result.attributes))
1406 return failure();
1407
1408 // Parse the function arguments.
1409 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::Paren))
1410 return failure();
1411
1412 bool isVarArg = parser.parseOptionalKeyword(keyword: "vararg").succeeded();
1413 if (isVarArg) {
1414 StringAttr varCalleeTypeAttrName =
1415 CallOp::getVarCalleeTypeAttrName(name: result.name);
1416 if (parser.parseLParen().failed() ||
1417 parser
1418 .parseAttribute(result&: varCalleeType, attrName: varCalleeTypeAttrName,
1419 attrs&: result.attributes)
1420 .failed() ||
1421 parser.parseRParen().failed())
1422 return failure();
1423 }
1424
1425 SMLoc opBundlesLoc = parser.getCurrentLocation();
1426 if (std::optional<ParseResult> result = parseOpBundles(
1427 p&: parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1428 result && failed(Result: *result))
1429 return failure();
1430 if (opBundleTags && !opBundleTags.empty())
1431 result.addAttribute(name: CallOp::getOpBundleTagsAttrName(name: result.name).getValue(),
1432 attr: opBundleTags);
1433
1434 if (parser.parseOptionalAttrDict(result&: result.attributes))
1435 return failure();
1436
1437 // Parse the trailing type list and resolve the operands.
1438 SmallVector<DictionaryAttr> argAttrs;
1439 SmallVector<DictionaryAttr> resultAttrs;
1440 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1441 argAttrs, resultAttrs))
1442 return failure();
1443 call_interface_impl::addArgAndResultAttrs(
1444 builder&: parser.getBuilder(), result, argAttrs, resultAttrs,
1445 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
1446 if (resolveOpBundleOperands(parser, loc: opBundlesLoc, state&: result, opBundleOperands,
1447 opBundleOperandTypes,
1448 opBundleSizesAttrName: getOpBundleSizesAttrName(name: result.name)))
1449 return failure();
1450
1451 int32_t numOpBundleOperands = 0;
1452 for (const auto &operands : opBundleOperands)
1453 numOpBundleOperands += operands.size();
1454
1455 result.addAttribute(
1456 name: CallOp::getOperandSegmentSizeAttr(),
1457 attr: parser.getBuilder().getDenseI32ArrayAttr(
1458 values: {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
1459 return success();
1460}
1461
1462LLVMFunctionType CallOp::getCalleeFunctionType() {
1463 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1464 return *varCalleeType;
1465 return getLLVMFuncType(context: getContext(), results: getResultTypes(), args: getArgOperands());
1466}
1467
1468///===---------------------------------------------------------------------===//
1469/// LLVM::InvokeOp
1470///===---------------------------------------------------------------------===//
1471
1472void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1473 ValueRange ops, Block *normal, ValueRange normalOps,
1474 Block *unwind, ValueRange unwindOps) {
1475 auto calleeType = func.getFunctionType();
1476 build(odsBuilder&: builder, odsState&: state, resultTypes: getCallOpResultTypes(calleeType),
1477 var_callee_type: getCallOpVarCalleeType(calleeType), callee: SymbolRefAttr::get(symbol: func), callee_operands: ops,
1478 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalDestOperands: normalOps, unwindDestOperands: unwindOps,
1479 branch_weights: nullptr, CConv: nullptr, op_bundle_operands: {}, op_bundle_tags: {}, normalDest: normal, unwindDest: unwind);
1480}
1481
1482void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1483 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1484 ValueRange normalOps, Block *unwind,
1485 ValueRange unwindOps) {
1486 build(odsBuilder&: builder, odsState&: state, resultTypes: tys,
1487 /*var_callee_type=*/nullptr, callee, callee_operands: ops, /*arg_attrs=*/nullptr,
1488 /*res_attrs=*/nullptr, normalDestOperands: normalOps, unwindDestOperands: unwindOps, branch_weights: nullptr, CConv: nullptr, op_bundle_operands: {}, op_bundle_tags: {},
1489 normalDest: normal, unwindDest: unwind);
1490}
1491
1492void InvokeOp::build(OpBuilder &builder, OperationState &state,
1493 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1494 ValueRange ops, Block *normal, ValueRange normalOps,
1495 Block *unwind, ValueRange unwindOps) {
1496 build(odsBuilder&: builder, odsState&: state, resultTypes: getCallOpResultTypes(calleeType),
1497 var_callee_type: getCallOpVarCalleeType(calleeType), callee, callee_operands: ops,
1498 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalDestOperands: normalOps, unwindDestOperands: unwindOps,
1499 branch_weights: nullptr, CConv: nullptr, op_bundle_operands: {}, op_bundle_tags: {}, normalDest: normal, unwindDest: unwind);
1500}
1501
1502SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1503 assert(index < getNumSuccessors() && "invalid successor index");
1504 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1505 : getUnwindDestOperandsMutable());
1506}
1507
1508CallInterfaceCallable InvokeOp::getCallableForCallee() {
1509 // Direct call.
1510 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1511 return calleeAttr;
1512 // Indirect call, callee Value is the first operand.
1513 return getOperand(i: 0);
1514}
1515
1516void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1517 // Direct call.
1518 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1519 auto symRef = cast<SymbolRefAttr>(Val&: callee);
1520 return setCalleeAttr(cast<FlatSymbolRefAttr>(Val&: symRef));
1521 }
1522 // Indirect call, callee Value is the first operand.
1523 return setOperand(i: 0, value: cast<Value>(Val&: callee));
1524}
1525
1526Operation::operand_range InvokeOp::getArgOperands() {
1527 return getCalleeOperands().drop_front(n: getCallee().has_value() ? 0 : 1);
1528}
1529
1530MutableOperandRange InvokeOp::getArgOperandsMutable() {
1531 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1532 getCalleeOperands().size());
1533}
1534
1535LogicalResult InvokeOp::verify() {
1536 if (failed(Result: verifyCallOpVarCalleeType(callOp: *this)))
1537 return failure();
1538
1539 Block *unwindDest = getUnwindDest();
1540 if (unwindDest->empty())
1541 return emitError(message: "must have at least one operation in unwind destination");
1542
1543 // In unwind destination, first operation must be LandingpadOp
1544 if (!isa<LandingpadOp>(Val: unwindDest->front()))
1545 return emitError(message: "first operation in unwind destination should be a "
1546 "llvm.landingpad operation");
1547
1548 if (failed(Result: verifyOperandBundles(op&: *this)))
1549 return failure();
1550
1551 return success();
1552}
1553
1554void InvokeOp::print(OpAsmPrinter &p) {
1555 auto callee = getCallee();
1556 bool isDirect = callee.has_value();
1557
1558 p << ' ';
1559
1560 // Print calling convention.
1561 if (getCConv() != LLVM::CConv::C)
1562 p << stringifyCConv(getCConv()) << ' ';
1563
1564 // Either function name or pointer
1565 if (isDirect)
1566 p.printSymbolName(symbolRef: callee.value());
1567 else
1568 p << getOperand(i: 0);
1569
1570 p << '(' << getCalleeOperands().drop_front(n: isDirect ? 0 : 1) << ')';
1571 p << " to ";
1572 p.printSuccessorAndUseList(successor: getNormalDest(), succOperands: getNormalDestOperands());
1573 p << " unwind ";
1574 p.printSuccessorAndUseList(successor: getUnwindDest(), succOperands: getUnwindDestOperands());
1575
1576 // Print the variadic callee type if the invoke is variadic.
1577 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1578 p << " vararg(" << *varCalleeType << ")";
1579
1580 if (!getOpBundleOperands().empty()) {
1581 p << " ";
1582 printOpBundles(p, op: *this, opBundleOperands: getOpBundleOperands(),
1583 opBundleOperandTypes: getOpBundleOperands().getTypes(), opBundleTags: getOpBundleTags());
1584 }
1585
1586 p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
1587 elidedAttrs: {getCalleeAttrName(), getOperandSegmentSizeAttr(),
1588 getCConvAttrName(), getVarCalleeTypeAttrName(),
1589 getOpBundleSizesAttrName(),
1590 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1591 getResAttrsAttrName()});
1592
1593 p << " : ";
1594 if (!isDirect)
1595 p << getOperand(i: 0).getType() << ", ";
1596 call_interface_impl::printFunctionSignature(
1597 p, argTypes: getCalleeOperands().drop_front(n: isDirect ? 0 : 1).getTypes(),
1598 argAttrs: getArgAttrsAttr(),
1599 /*isVariadic=*/false, resultTypes: getResultTypes(), resultAttrs: getResAttrsAttr());
1600}
1601
1602// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1603// `(` ssa-use-list `)`
1604// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1605// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1606// ( `vararg(` var-callee-type `)` )?
1607// ( `[` op-bundles-list `]` )?
1608// attribute-dict? `:` (type `,`)?
1609// function-type-with-argument-attributes
1610ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1611 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1612 SymbolRefAttr funcAttr;
1613 TypeAttr varCalleeType;
1614 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
1615 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1616 ArrayAttr opBundleTags;
1617 Block *normalDest, *unwindDest;
1618 SmallVector<Value, 4> normalOperands, unwindOperands;
1619 Builder &builder = parser.getBuilder();
1620
1621 // Default to C Calling Convention if no keyword is provided.
1622 result.addAttribute(
1623 name: getCConvAttrName(name: result.name),
1624 attr: CConvAttr::get(context: parser.getContext(),
1625 CallingConv: parseOptionalLLVMKeyword<CConv>(parser, defaultValue: LLVM::CConv::C)));
1626
1627 // Parse a function pointer for indirect calls.
1628 if (parseOptionalCallFuncPtr(parser, operands))
1629 return failure();
1630 bool isDirect = operands.empty();
1631
1632 // Parse a function identifier for direct calls.
1633 if (isDirect && parser.parseAttribute(result&: funcAttr, attrName: "callee", attrs&: result.attributes))
1634 return failure();
1635
1636 // Parse the function arguments.
1637 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::Paren) ||
1638 parser.parseKeyword(keyword: "to") ||
1639 parser.parseSuccessorAndUseList(dest&: normalDest, operands&: normalOperands) ||
1640 parser.parseKeyword(keyword: "unwind") ||
1641 parser.parseSuccessorAndUseList(dest&: unwindDest, operands&: unwindOperands))
1642 return failure();
1643
1644 bool isVarArg = parser.parseOptionalKeyword(keyword: "vararg").succeeded();
1645 if (isVarArg) {
1646 StringAttr varCalleeTypeAttrName =
1647 InvokeOp::getVarCalleeTypeAttrName(name: result.name);
1648 if (parser.parseLParen().failed() ||
1649 parser
1650 .parseAttribute(result&: varCalleeType, attrName: varCalleeTypeAttrName,
1651 attrs&: result.attributes)
1652 .failed() ||
1653 parser.parseRParen().failed())
1654 return failure();
1655 }
1656
1657 SMLoc opBundlesLoc = parser.getCurrentLocation();
1658 if (std::optional<ParseResult> result = parseOpBundles(
1659 p&: parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1660 result && failed(Result: *result))
1661 return failure();
1662 if (opBundleTags && !opBundleTags.empty())
1663 result.addAttribute(
1664 name: InvokeOp::getOpBundleTagsAttrName(name: result.name).getValue(),
1665 attr: opBundleTags);
1666
1667 if (parser.parseOptionalAttrDict(result&: result.attributes))
1668 return failure();
1669
1670 // Parse the trailing type list and resolve the function operands.
1671 SmallVector<DictionaryAttr> argAttrs;
1672 SmallVector<DictionaryAttr> resultAttrs;
1673 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1674 argAttrs, resultAttrs))
1675 return failure();
1676 call_interface_impl::addArgAndResultAttrs(
1677 builder&: parser.getBuilder(), result, argAttrs, resultAttrs,
1678 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
1679
1680 if (resolveOpBundleOperands(parser, loc: opBundlesLoc, state&: result, opBundleOperands,
1681 opBundleOperandTypes,
1682 opBundleSizesAttrName: getOpBundleSizesAttrName(name: result.name)))
1683 return failure();
1684
1685 result.addSuccessors(newSuccessors: {normalDest, unwindDest});
1686 result.addOperands(newOperands: normalOperands);
1687 result.addOperands(newOperands: unwindOperands);
1688
1689 int32_t numOpBundleOperands = 0;
1690 for (const auto &operands : opBundleOperands)
1691 numOpBundleOperands += operands.size();
1692
1693 result.addAttribute(
1694 name: InvokeOp::getOperandSegmentSizeAttr(),
1695 attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(operands.size()),
1696 static_cast<int32_t>(normalOperands.size()),
1697 static_cast<int32_t>(unwindOperands.size()),
1698 numOpBundleOperands}));
1699 return success();
1700}
1701
1702LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1703 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1704 return *varCalleeType;
1705 return getLLVMFuncType(context: getContext(), results: getResultTypes(), args: getArgOperands());
1706}
1707
1708///===----------------------------------------------------------------------===//
1709/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1710///===----------------------------------------------------------------------===//
1711
1712LogicalResult LandingpadOp::verify() {
1713 Value value;
1714 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1715 if (!func.getPersonality())
1716 return emitError(
1717 message: "llvm.landingpad needs to be in a function with a personality");
1718 }
1719
1720 // Consistency of llvm.landingpad result types is checked in
1721 // LLVMFuncOp::verify().
1722
1723 if (!getCleanup() && getOperands().empty())
1724 return emitError(message: "landingpad instruction expects at least one clause or "
1725 "cleanup attribute");
1726
1727 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1728 value = getOperand(i: idx);
1729 bool isFilter = llvm::isa<LLVMArrayType>(Val: value.getType());
1730 if (isFilter) {
1731 // FIXME: Verify filter clauses when arrays are appropriately handled
1732 } else {
1733 // catch - global addresses only.
1734 // Bitcast ops should have global addresses as their args.
1735 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1736 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1737 continue;
1738 return emitError(message: "constant clauses expected").attachNote(noteLoc: bcOp.getLoc())
1739 << "global addresses expected as operand to "
1740 "bitcast used in clauses for landingpad";
1741 }
1742 // ZeroOp and AddressOfOp allowed
1743 if (value.getDefiningOp<ZeroOp>())
1744 continue;
1745 if (value.getDefiningOp<AddressOfOp>())
1746 continue;
1747 return emitError(message: "clause #")
1748 << idx << " is not a known constant - null, addressof, bitcast";
1749 }
1750 }
1751 return success();
1752}
1753
1754void LandingpadOp::print(OpAsmPrinter &p) {
1755 p << (getCleanup() ? " cleanup " : " ");
1756
1757 // Clauses
1758 for (auto value : getOperands()) {
1759 // Similar to llvm - if clause is an array type then it is filter
1760 // clause else catch clause
1761 bool isArrayTy = llvm::isa<LLVMArrayType>(Val: value.getType());
1762 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1763 << value.getType() << ") ";
1764 }
1765
1766 p.printOptionalAttrDict(attrs: (*this)->getAttrs(), elidedAttrs: {"cleanup"});
1767
1768 p << ": " << getType();
1769}
1770
1771// <operation> ::= `llvm.landingpad` `cleanup`?
1772// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1773ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1774 // Check for cleanup
1775 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "cleanup")))
1776 result.addAttribute(name: "cleanup", attr: parser.getBuilder().getUnitAttr());
1777
1778 // Parse clauses with types
1779 while (succeeded(Result: parser.parseOptionalLParen()) &&
1780 (succeeded(Result: parser.parseOptionalKeyword(keyword: "filter")) ||
1781 succeeded(Result: parser.parseOptionalKeyword(keyword: "catch")))) {
1782 OpAsmParser::UnresolvedOperand operand;
1783 Type ty;
1784 if (parser.parseOperand(result&: operand) || parser.parseColon() ||
1785 parser.parseType(result&: ty) ||
1786 parser.resolveOperand(operand, type: ty, result&: result.operands) ||
1787 parser.parseRParen())
1788 return failure();
1789 }
1790
1791 Type type;
1792 if (parser.parseColon() || parser.parseType(result&: type))
1793 return failure();
1794
1795 result.addTypes(newTypes: type);
1796 return success();
1797}
1798
1799//===----------------------------------------------------------------------===//
1800// ExtractValueOp
1801//===----------------------------------------------------------------------===//
1802
1803/// Extract the type at `position` in the LLVM IR aggregate type
1804/// `containerType`. Each element of `position` is an index into a nested
1805/// aggregate type. Return the resulting type or emit an error.
1806static Type getInsertExtractValueElementType(
1807 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1808 ArrayRef<int64_t> position) {
1809 Type llvmType = containerType;
1810 if (!isCompatibleType(type: containerType)) {
1811 emitError("expected LLVM IR Dialect type, got ") << containerType;
1812 return {};
1813 }
1814
1815 // Infer the element type from the structure type: iteratively step inside the
1816 // type by taking the element type, indexed by the position attribute for
1817 // structures. Check the position index before accessing, it is supposed to
1818 // be in bounds.
1819 for (int64_t idx : position) {
1820 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(Val&: llvmType)) {
1821 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1822 emitError("position out of bounds: ") << idx;
1823 return {};
1824 }
1825 llvmType = arrayType.getElementType();
1826 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(Val&: llvmType)) {
1827 if (idx < 0 ||
1828 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1829 emitError("position out of bounds: ") << idx;
1830 return {};
1831 }
1832 llvmType = structType.getBody()[idx];
1833 } else {
1834 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1835 return {};
1836 }
1837 }
1838 return llvmType;
1839}
1840
1841/// Extract the type at `position` in the wrapped LLVM IR aggregate type
1842/// `containerType`.
1843static Type getInsertExtractValueElementType(Type llvmType,
1844 ArrayRef<int64_t> position) {
1845 for (int64_t idx : position) {
1846 if (auto structType = llvm::dyn_cast<LLVMStructType>(Val&: llvmType))
1847 llvmType = structType.getBody()[idx];
1848 else
1849 llvmType = llvm::cast<LLVMArrayType>(Val&: llvmType).getElementType();
1850 }
1851 return llvmType;
1852}
1853
/// Fold llvm.extractvalue. Several cases below do not produce a new value.
/// Instead they rewrite this op in place (its position and/or container
/// operand) and return this op's own result to report the in-place update.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  // Extracting from another llvm.extractvalue: concatenate the two positions
  // and read directly from the inner container (in-place update).
  if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
    SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
    newPos.append(in_start: getPosition().begin(), in_end: getPosition().end());
    setPosition(newPos);
    getContainerMutable().set(extractValueOp.getContainer());
    return getResult();
  }

  // Extracting from a DenseElementsAttr constant whose element type equals
  // the result type: a splat folds to its splat value, and a single-index
  // position folds to the element at that index.
  {
    DenseElementsAttr constval;
    matchPattern(value: getContainer(), pattern: m_Constant(bind_value: &constval));
    if (constval && constval.getElementType() == getType()) {
      if (isa<SplatElementsAttr>(Val: constval))
        return constval.getSplatValue<Attribute>();
      if (getPosition().size() == 1)
        return constval.getValues<Attribute>()[getPosition()[0]];
    }
  }

  // Walk back through the chain of llvm.insertvalue ops that produce the
  // container. `result` stays null unless Case 3 rewrites the container
  // operand in place, in which case it holds this op's own result.
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  OpFoldResult result = {};
  // Position still to extract, relative to the value currently inspected.
  ArrayRef<int64_t> extractPos = getPosition();
  // Set once the walk has descended into an inserted value (Case 2). From
  // then on this op's own container operand must no longer be replaced.
  bool switchedToInsertedValue = false;
  while (insertValueOp) {
    ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
    auto extractPosSize = extractPos.size();
    auto insertPosSize = insertPos.size();

    // Case 1: Exact match of positions.
    if (extractPos == insertPos)
      return insertValueOp.getValue();

    // Case 2: Insert position is a prefix of extract position. Continue
    // traversal with the inserted value. Example:
    // ```
    // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
    // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
    // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
    // %3 = llvm.insertvalue %2, %foo[0]
    //     : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    // %4 = llvm.extractvalue %3[0, 0]
    //     : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    // ```
    // In the above example, %4 is folded to %arg1.
    if (extractPosSize > insertPosSize &&
        extractPos.take_front(N: insertPosSize) == insertPos) {
      insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
      extractPos = extractPos.drop_front(N: insertPosSize);
      switchedToInsertedValue = true;
      continue;
    }

    // Case 3: Try to continue the traversal with the container value.
    unsigned min = std::min(a: extractPosSize, b: insertPosSize);

    // If one is fully prefix of the other, stop propagating back as it will
    // miss dependencies. For instance, %3 should not fold to %f0 in the
    // following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (extractPos.take_front(N: min) == insertPos.take_front(N: min))
      return result;
    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    if (!switchedToInsertedValue) {
      // Do not swap out the container operand if we decided earlier to
      // continue the traversal with the inserted value (Case 2).
      getContainerMutable().assign(value: insertValueOp.getContainer());
      result = getResult();
    }
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
1935
1936LogicalResult ExtractValueOp::verify() {
1937 auto emitError = [this](StringRef msg) { return emitOpError(message: msg); };
1938 Type valueType = getInsertExtractValueElementType(
1939 emitError, containerType: getContainer().getType(), position: getPosition());
1940 if (!valueType)
1941 return failure();
1942
1943 if (getRes().getType() != valueType)
1944 return emitOpError() << "Type mismatch: extracting from "
1945 << getContainer().getType() << " should produce "
1946 << valueType << " but this op returns "
1947 << getRes().getType();
1948 return success();
1949}
1950
1951void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
1952 Value container, ArrayRef<int64_t> position) {
1953 build(odsBuilder&: builder, odsState&: state,
1954 res: getInsertExtractValueElementType(llvmType: container.getType(), position),
1955 container, position: builder.getAttr<DenseI64ArrayAttr>(args&: position));
1956}
1957
1958//===----------------------------------------------------------------------===//
1959// InsertValueOp
1960//===----------------------------------------------------------------------===//
1961
1962/// Infer the value type from the container type and position.
1963static ParseResult
1964parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
1965 Type containerType,
1966 DenseI64ArrayAttr position) {
1967 valueType = getInsertExtractValueElementType(
1968 emitError: [&](StringRef msg) {
1969 return parser.emitError(loc: parser.getCurrentLocation(), message: msg);
1970 },
1971 containerType, position: position.asArrayRef());
1972 return success(IsSuccess: !!valueType);
1973}
1974
/// Nothing to print for an inferred type. This is the printer counterpart of
/// parseInsertExtractValueElementType: the value type is recomputed from
/// `containerType` and `position` when parsing, so all parameters are
/// intentionally unused.
static void printInsertExtractValueElementType(AsmPrinter &printer,
                                               Operation *op, Type valueType,
                                               Type containerType,
                                               DenseI64ArrayAttr position) {}
1980
1981LogicalResult InsertValueOp::verify() {
1982 auto emitError = [this](StringRef msg) { return emitOpError(message: msg); };
1983 Type valueType = getInsertExtractValueElementType(
1984 emitError, containerType: getContainer().getType(), position: getPosition());
1985 if (!valueType)
1986 return failure();
1987
1988 if (getValue().getType() != valueType)
1989 return emitOpError() << "Type mismatch: cannot insert "
1990 << getValue().getType() << " into "
1991 << getContainer().getType();
1992
1993 return success();
1994}
1995
1996//===----------------------------------------------------------------------===//
1997// ReturnOp
1998//===----------------------------------------------------------------------===//
1999
2000LogicalResult ReturnOp::verify() {
2001 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
2002 if (!parent)
2003 return success();
2004
2005 Type expectedType = parent.getFunctionType().getReturnType();
2006 if (llvm::isa<LLVMVoidType>(Val: expectedType)) {
2007 if (!getArg())
2008 return success();
2009 InFlightDiagnostic diag = emitOpError(message: "expected no operands");
2010 diag.attachNote(noteLoc: parent->getLoc()) << "when returning from function";
2011 return diag;
2012 }
2013 if (!getArg()) {
2014 if (llvm::isa<LLVMVoidType>(Val: expectedType))
2015 return success();
2016 InFlightDiagnostic diag = emitOpError(message: "expected 1 operand");
2017 diag.attachNote(noteLoc: parent->getLoc()) << "when returning from function";
2018 return diag;
2019 }
2020 if (expectedType != getArg().getType()) {
2021 InFlightDiagnostic diag = emitOpError(message: "mismatching result types");
2022 diag.attachNote(noteLoc: parent->getLoc()) << "when returning from function";
2023 return diag;
2024 }
2025 return success();
2026}
2027
2028//===----------------------------------------------------------------------===//
2029// LLVM::AddressOfOp.
2030//===----------------------------------------------------------------------===//
2031
2032static Operation *parentLLVMModule(Operation *op) {
2033 Operation *module = op->getParentOp();
2034 while (module && !satisfiesLLVMModule(op: module))
2035 module = module->getParentOp();
2036 assert(module && "unexpected operation outside of a module");
2037 return module;
2038}
2039
2040GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2041 return dyn_cast_or_null<GlobalOp>(
2042 Val: symbolTable.lookupSymbolIn(symbolTableOp: parentLLVMModule(op: *this), name: getGlobalNameAttr()));
2043}
2044
2045LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2046 return dyn_cast_or_null<LLVMFuncOp>(
2047 Val: symbolTable.lookupSymbolIn(symbolTableOp: parentLLVMModule(op: *this), name: getGlobalNameAttr()));
2048}
2049
2050AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
2051 return dyn_cast_or_null<AliasOp>(
2052 Val: symbolTable.lookupSymbolIn(symbolTableOp: parentLLVMModule(op: *this), name: getGlobalNameAttr()));
2053}
2054
2055LogicalResult
2056AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2057 Operation *symbol =
2058 symbolTable.lookupSymbolIn(symbolTableOp: parentLLVMModule(op: *this), name: getGlobalNameAttr());
2059
2060 auto global = dyn_cast_or_null<GlobalOp>(Val: symbol);
2061 auto function = dyn_cast_or_null<LLVMFuncOp>(Val: symbol);
2062 auto alias = dyn_cast_or_null<AliasOp>(Val: symbol);
2063
2064 if (!global && !function && !alias)
2065 return emitOpError(message: "must reference a global defined by 'llvm.mlir.global', "
2066 "'llvm.mlir.alias' or 'llvm.func'");
2067
2068 LLVMPointerType type = getType();
2069 if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
2070 (alias && alias.getAddrSpace() != type.getAddressSpace()))
2071 return emitOpError(message: "pointer address space must match address space of the "
2072 "referenced global or alias");
2073
2074 return success();
2075}
2076
2077// AddressOfOp constant-folds to the global symbol name.
2078OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2079 return getGlobalNameAttr();
2080}
2081
2082//===----------------------------------------------------------------------===//
2083// LLVM::DSOLocalEquivalentOp
2084//===----------------------------------------------------------------------===//
2085
2086LLVMFuncOp
2087DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) {
2088 return dyn_cast_or_null<LLVMFuncOp>(Val: symbolTable.lookupSymbolIn(
2089 symbolTableOp: parentLLVMModule(op: *this), name: getFunctionNameAttr()));
2090}
2091
2092AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) {
2093 return dyn_cast_or_null<AliasOp>(Val: symbolTable.lookupSymbolIn(
2094 symbolTableOp: parentLLVMModule(op: *this), name: getFunctionNameAttr()));
2095}
2096
2097LogicalResult
2098DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2099 Operation *symbol = symbolTable.lookupSymbolIn(symbolTableOp: parentLLVMModule(op: *this),
2100 name: getFunctionNameAttr());
2101 auto function = dyn_cast_or_null<LLVMFuncOp>(Val: symbol);
2102 auto alias = dyn_cast_or_null<AliasOp>(Val: symbol);
2103
2104 if (!function && !alias)
2105 return emitOpError(
2106 message: "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'");
2107
2108 if (alias) {
2109 if (alias.getInitializer()
2110 .walk(callback: [&](AddressOfOp addrOp) {
2111 if (addrOp.getGlobal(symbolTable))
2112 return WalkResult::interrupt();
2113 return WalkResult::advance();
2114 })
2115 .wasInterrupted())
2116 return emitOpError(message: "must reference an alias to a function");
2117 }
2118
2119 if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) ||
2120 (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak))
2121 return emitOpError(
2122 message: "target function with 'extern_weak' linkage not allowed");
2123
2124 return success();
2125}
2126
2127/// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent
2128/// attribute.
2129OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) {
2130 return DSOLocalEquivalentAttr::get(context: getContext(), sym: getFunctionNameAttr());
2131}
2132
2133//===----------------------------------------------------------------------===//
// Builder and verifier for LLVM::ComdatOp.
2135//===----------------------------------------------------------------------===//
2136
2137void ComdatOp::build(OpBuilder &builder, OperationState &result,
2138 StringRef symName) {
2139 result.addAttribute(name: getSymNameAttrName(name: result.name),
2140 attr: builder.getStringAttr(bytes: symName));
2141 Region *body = result.addRegion();
2142 body->emplaceBlock();
2143}
2144
2145LogicalResult ComdatOp::verifyRegions() {
2146 Region &body = getBody();
2147 for (Operation &op : body.getOps())
2148 if (!isa<ComdatSelectorOp>(Val: op))
2149 return op.emitError(
2150 message: "only comdat selector symbols can appear in a comdat region");
2151
2152 return success();
2153}
2154
2155//===----------------------------------------------------------------------===//
2156// Builder, printer and verifier for LLVM::GlobalOp.
2157//===----------------------------------------------------------------------===//
2158
2159void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
2160 bool isConstant, Linkage linkage, StringRef name,
2161 Attribute value, uint64_t alignment, unsigned addrSpace,
2162 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
2163 ArrayRef<NamedAttribute> attrs,
2164 ArrayRef<Attribute> dbgExprs) {
2165 result.addAttribute(name: getSymNameAttrName(name: result.name),
2166 attr: builder.getStringAttr(bytes: name));
2167 result.addAttribute(name: getGlobalTypeAttrName(name: result.name), attr: TypeAttr::get(type));
2168 if (isConstant)
2169 result.addAttribute(name: getConstantAttrName(name: result.name),
2170 attr: builder.getUnitAttr());
2171 if (value)
2172 result.addAttribute(name: getValueAttrName(name: result.name), attr: value);
2173 if (dsoLocal)
2174 result.addAttribute(name: getDsoLocalAttrName(name: result.name),
2175 attr: builder.getUnitAttr());
2176 if (threadLocal)
2177 result.addAttribute(name: getThreadLocal_AttrName(name: result.name),
2178 attr: builder.getUnitAttr());
2179 if (comdat)
2180 result.addAttribute(name: getComdatAttrName(name: result.name), attr: comdat);
2181
2182 // Only add an alignment attribute if the "alignment" input
2183 // is different from 0. The value must also be a power of two, but
2184 // this is tested in GlobalOp::verify, not here.
2185 if (alignment != 0)
2186 result.addAttribute(name: getAlignmentAttrName(name: result.name),
2187 attr: builder.getI64IntegerAttr(value: alignment));
2188
2189 result.addAttribute(name: getLinkageAttrName(name: result.name),
2190 attr: LinkageAttr::get(context: builder.getContext(), linkage));
2191 if (addrSpace != 0)
2192 result.addAttribute(name: getAddrSpaceAttrName(name: result.name),
2193 attr: builder.getI32IntegerAttr(value: addrSpace));
2194 result.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
2195
2196 if (!dbgExprs.empty())
2197 result.addAttribute(name: getDbgExprsAttrName(name: result.name),
2198 attr: ArrayAttr::get(context: builder.getContext(), value: dbgExprs));
2199
2200 result.addRegion();
2201}
2202
2203template <typename OpType>
2204static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op) {
2205 p << ' ' << stringifyLinkage(op.getLinkage()) << ' ';
2206 StringRef visibility = stringifyVisibility(op.getVisibility_());
2207 if (!visibility.empty())
2208 p << visibility << ' ';
2209 if (op.getThreadLocal_())
2210 p << "thread_local ";
2211 if (auto unnamedAddr = op.getUnnamedAddr()) {
2212 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2213 if (!str.empty())
2214 p << str << ' ';
2215 }
2216}
2217
2218void GlobalOp::print(OpAsmPrinter &p) {
2219 printCommonGlobalAndAlias<GlobalOp>(p, op: *this);
2220 if (getConstant())
2221 p << "constant ";
2222 p.printSymbolName(symbolRef: getSymName());
2223 p << '(';
2224 if (auto value = getValueOrNull())
2225 p.printAttribute(attr: value);
2226 p << ')';
2227 if (auto comdat = getComdat())
2228 p << " comdat(" << *comdat << ')';
2229
2230 // Note that the alignment attribute is printed using the
2231 // default syntax here, even though it is an inherent attribute
2232 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
2233 p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
2234 elidedAttrs: {SymbolTable::getSymbolAttrName(),
2235 getGlobalTypeAttrName(), getConstantAttrName(),
2236 getValueAttrName(), getLinkageAttrName(),
2237 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2238 getVisibility_AttrName(), getComdatAttrName()});
2239
2240 // Print the trailing type unless it's a string global.
2241 if (llvm::dyn_cast_or_null<StringAttr>(Val: getValueOrNull()))
2242 return;
2243 p << " : " << getType();
2244
2245 Region &initializer = getInitializerRegion();
2246 if (!initializer.empty()) {
2247 p << ' ';
2248 p.printRegion(blocks&: initializer, /*printEntryBlockArgs=*/false);
2249 }
2250}
2251
2252static LogicalResult verifyComdat(Operation *op,
2253 std::optional<SymbolRefAttr> attr) {
2254 if (!attr)
2255 return success();
2256
2257 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(from: op, symbol: *attr);
2258 if (!isa_and_nonnull<ComdatSelectorOp>(Val: comdatSelector))
2259 return op->emitError() << "expected comdat symbol";
2260
2261 return success();
2262}
2263
2264static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
2265 llvm::DenseSet<BlockTagAttr> blockTags;
2266 // Note that presence of `BlockTagOp`s currently can't prevent an unrecheable
2267 // block to be removed by canonicalizer's region simplify pass, which needs to
2268 // be dialect aware to allow extra constraints to be described.
2269 WalkResult res = funcOp.walk(callback: [&](BlockTagOp blockTagOp) {
2270 if (blockTags.contains(V: blockTagOp.getTag())) {
2271 blockTagOp.emitError()
2272 << "duplicate block tag '" << blockTagOp.getTag().getId()
2273 << "' in the same function: ";
2274 return WalkResult::interrupt();
2275 }
2276 blockTags.insert(V: blockTagOp.getTag());
2277 return WalkResult::advance();
2278 });
2279
2280 return failure(IsFailure: res.wasInterrupted());
2281}
2282
2283/// Parse common attributes that might show up in the same order in both
2284/// GlobalOp and AliasOp.
2285template <typename OpType>
2286static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser,
2287 OperationState &result) {
2288 MLIRContext *ctx = parser.getContext();
2289 // Parse optional linkage, default to External.
2290 result.addAttribute(
2291 OpType::getLinkageAttrName(result.name),
2292 LLVM::LinkageAttr::get(context: ctx, linkage: parseOptionalLLVMKeyword<Linkage>(
2293 parser, defaultValue: LLVM::Linkage::External)));
2294
2295 // Parse optional visibility, default to Default.
2296 result.addAttribute(OpType::getVisibility_AttrName(result.name),
2297 parser.getBuilder().getI64IntegerAttr(
2298 value: parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2299 parser, defaultValue: LLVM::Visibility::Default)));
2300
2301 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "thread_local")))
2302 result.addAttribute(OpType::getThreadLocal_AttrName(result.name),
2303 parser.getBuilder().getUnitAttr());
2304
2305 // Parse optional UnnamedAddr, default to None.
2306 result.addAttribute(OpType::getUnnamedAddrAttrName(result.name),
2307 parser.getBuilder().getI64IntegerAttr(
2308 value: parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2309 parser, defaultValue: LLVM::UnnamedAddr::None)));
2310
2311 return success();
2312}
2313
2314// operation ::= `llvm.mlir.global` linkage? visibility?
2315// (`unnamed_addr` | `local_unnamed_addr`)?
2316// `thread_local`? `constant`? `@` identifier
2317// `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
2318// attribute-list? (`:` type)? region?
2319//
2320// The type can be omitted for string attributes, in which case it will be
2321// inferred from the value of the string as [strlen(value) x i8].
2322ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
2323 // Call into common parsing between GlobalOp and AliasOp.
2324 if (parseCommonGlobalAndAlias<GlobalOp>(parser, result).failed())
2325 return failure();
2326
2327 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "constant")))
2328 result.addAttribute(name: getConstantAttrName(name: result.name),
2329 attr: parser.getBuilder().getUnitAttr());
2330
2331 StringAttr name;
2332 if (parser.parseSymbolName(result&: name, attrName: getSymNameAttrName(name: result.name),
2333 attrs&: result.attributes) ||
2334 parser.parseLParen())
2335 return failure();
2336
2337 Attribute value;
2338 if (parser.parseOptionalRParen()) {
2339 if (parser.parseAttribute(result&: value, attrName: getValueAttrName(name: result.name),
2340 attrs&: result.attributes) ||
2341 parser.parseRParen())
2342 return failure();
2343 }
2344
2345 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "comdat"))) {
2346 SymbolRefAttr comdat;
2347 if (parser.parseLParen() || parser.parseAttribute(result&: comdat) ||
2348 parser.parseRParen())
2349 return failure();
2350
2351 result.addAttribute(name: getComdatAttrName(name: result.name), attr: comdat);
2352 }
2353
2354 SmallVector<Type, 1> types;
2355 if (parser.parseOptionalAttrDict(result&: result.attributes) ||
2356 parser.parseOptionalColonTypeList(result&: types))
2357 return failure();
2358
2359 if (types.size() > 1)
2360 return parser.emitError(loc: parser.getNameLoc(), message: "expected zero or one type");
2361
2362 Region &initRegion = *result.addRegion();
2363 if (types.empty()) {
2364 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(Val&: value)) {
2365 MLIRContext *context = parser.getContext();
2366 auto arrayType = LLVM::LLVMArrayType::get(elementType: IntegerType::get(context, width: 8),
2367 numElements: strAttr.getValue().size());
2368 types.push_back(Elt: arrayType);
2369 } else {
2370 return parser.emitError(loc: parser.getNameLoc(),
2371 message: "type can only be omitted for string globals");
2372 }
2373 } else {
2374 OptionalParseResult parseResult =
2375 parser.parseOptionalRegion(region&: initRegion, /*arguments=*/{},
2376 /*argTypes=*/enableNameShadowing: {});
2377 if (parseResult.has_value() && failed(Result: *parseResult))
2378 return failure();
2379 }
2380
2381 result.addAttribute(name: getGlobalTypeAttrName(name: result.name),
2382 attr: TypeAttr::get(type: types[0]));
2383 return success();
2384}
2385
2386static bool isZeroAttribute(Attribute value) {
2387 if (auto intValue = llvm::dyn_cast<IntegerAttr>(Val&: value))
2388 return intValue.getValue().isZero();
2389 if (auto fpValue = llvm::dyn_cast<FloatAttr>(Val&: value))
2390 return fpValue.getValue().isZero();
2391 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(Val&: value))
2392 return isZeroAttribute(value: splatValue.getSplatValue<Attribute>());
2393 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(Val&: value))
2394 return llvm::all_of(Range: elementsValue.getValues<Attribute>(), P: isZeroAttribute);
2395 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(Val&: value))
2396 return llvm::all_of(Range: arrayValue.getValue(), P: isZeroAttribute);
2397 return false;
2398}
2399
2400LogicalResult GlobalOp::verify() {
2401 bool validType = isCompatibleOuterType(type: getType())
2402 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2403 LLVMMetadataType, LLVMLabelType>(Val: getType())
2404 : llvm::isa<PointerElementTypeInterface>(Val: getType());
2405 if (!validType)
2406 return emitOpError(
2407 message: "expects type to be a valid element type for an LLVM global");
2408 if ((*this)->getParentOp() && !satisfiesLLVMModule(op: (*this)->getParentOp()))
2409 return emitOpError(message: "must appear at the module level");
2410
2411 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(Val: getValueOrNull())) {
2412 auto type = llvm::dyn_cast<LLVMArrayType>(Val: getType());
2413 IntegerType elementType =
2414 type ? llvm::dyn_cast<IntegerType>(Val: type.getElementType()) : nullptr;
2415 if (!elementType || elementType.getWidth() != 8 ||
2416 type.getNumElements() != strAttr.getValue().size())
2417 return emitOpError(
2418 message: "requires an i8 array type of the length equal to that of the string "
2419 "attribute");
2420 }
2421
2422 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(Val: getType())) {
2423 if (!targetExtType.hasProperty(Prop: LLVMTargetExtType::CanBeGlobal))
2424 return emitOpError()
2425 << "this target extension type cannot be used in a global";
2426
2427 if (Attribute value = getValueOrNull())
2428 return emitOpError() << "global with target extension type can only be "
2429 "initialized with zero-initializer";
2430 }
2431
2432 if (getLinkage() == Linkage::Common) {
2433 if (Attribute value = getValueOrNull()) {
2434 if (!isZeroAttribute(value)) {
2435 return emitOpError()
2436 << "expected zero value for '"
2437 << stringifyLinkage(Linkage::Common) << "' linkage";
2438 }
2439 }
2440 }
2441
2442 if (getLinkage() == Linkage::Appending) {
2443 if (!llvm::isa<LLVMArrayType>(Val: getType())) {
2444 return emitOpError() << "expected array type for '"
2445 << stringifyLinkage(Linkage::Appending)
2446 << "' linkage";
2447 }
2448 }
2449
2450 if (failed(Result: verifyComdat(op: *this, attr: getComdat())))
2451 return failure();
2452
2453 std::optional<uint64_t> alignAttr = getAlignment();
2454 if (alignAttr.has_value()) {
2455 uint64_t value = alignAttr.value();
2456 if (!llvm::isPowerOf2_64(Value: value))
2457 return emitError() << "alignment attribute is not a power of 2";
2458 }
2459
2460 return success();
2461}
2462
2463LogicalResult GlobalOp::verifyRegions() {
2464 if (Block *b = getInitializerBlock()) {
2465 ReturnOp ret = cast<ReturnOp>(Val: b->getTerminator());
2466 if (ret.operand_type_begin() == ret.operand_type_end())
2467 return emitOpError(message: "initializer region cannot return void");
2468 if (*ret.operand_type_begin() != getType())
2469 return emitOpError(message: "initializer region type ")
2470 << *ret.operand_type_begin() << " does not match global type "
2471 << getType();
2472
2473 for (Operation &op : *b) {
2474 auto iface = dyn_cast<MemoryEffectOpInterface>(Val&: op);
2475 if (!iface || !iface.hasNoEffect())
2476 return op.emitError()
2477 << "ops with side effects not allowed in global initializers";
2478 }
2479
2480 if (getValueOrNull())
2481 return emitOpError(message: "cannot have both initializer value and region");
2482 }
2483
2484 return success();
2485}
2486
2487//===----------------------------------------------------------------------===//
2488// LLVM::GlobalCtorsOp
2489//===----------------------------------------------------------------------===//
2490
2491LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
2492 if (data.empty())
2493 return success();
2494
2495 if (llvm::all_of(Range: data.getAsRange<Attribute>(), P: [](Attribute v) {
2496 return isa<FlatSymbolRefAttr, ZeroAttr>(Val: v);
2497 }))
2498 return success();
2499 return op->emitError(message: "data element must be symbol or #llvm.zero");
2500}
2501
2502LogicalResult
2503GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2504 for (Attribute ctor : getCtors()) {
2505 if (failed(Result: verifySymbolAttrUse(symbol: llvm::cast<FlatSymbolRefAttr>(Val&: ctor), op: *this,
2506 symbolTable)))
2507 return failure();
2508 }
2509 return success();
2510}
2511
2512LogicalResult GlobalCtorsOp::verify() {
2513 if (checkGlobalXtorData(op: *this, data: getData()).failed())
2514 return failure();
2515
2516 if (getCtors().size() == getPriorities().size() &&
2517 getCtors().size() == getData().size())
2518 return success();
2519 return emitError(
2520 message: "ctors, priorities, and data must have the same number of elements");
2521}
2522
2523//===----------------------------------------------------------------------===//
2524// LLVM::GlobalDtorsOp
2525//===----------------------------------------------------------------------===//
2526
2527LogicalResult
2528GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2529 for (Attribute dtor : getDtors()) {
2530 if (failed(Result: verifySymbolAttrUse(symbol: llvm::cast<FlatSymbolRefAttr>(Val&: dtor), op: *this,
2531 symbolTable)))
2532 return failure();
2533 }
2534 return success();
2535}
2536
2537LogicalResult GlobalDtorsOp::verify() {
2538 if (checkGlobalXtorData(op: *this, data: getData()).failed())
2539 return failure();
2540
2541 if (getDtors().size() == getPriorities().size() &&
2542 getDtors().size() == getData().size())
2543 return success();
2544 return emitError(
2545 message: "dtors, priorities, and data must have the same number of elements");
2546}
2547
2548//===----------------------------------------------------------------------===//
2549// Builder, printer and verifier for LLVM::AliasOp.
2550//===----------------------------------------------------------------------===//
2551
2552void AliasOp::build(OpBuilder &builder, OperationState &result, Type type,
2553 Linkage linkage, StringRef name, bool dsoLocal,
2554 bool threadLocal, ArrayRef<NamedAttribute> attrs) {
2555 result.addAttribute(name: getSymNameAttrName(name: result.name),
2556 attr: builder.getStringAttr(bytes: name));
2557 result.addAttribute(name: getAliasTypeAttrName(name: result.name), attr: TypeAttr::get(type));
2558 if (dsoLocal)
2559 result.addAttribute(name: getDsoLocalAttrName(name: result.name),
2560 attr: builder.getUnitAttr());
2561 if (threadLocal)
2562 result.addAttribute(name: getThreadLocal_AttrName(name: result.name),
2563 attr: builder.getUnitAttr());
2564
2565 result.addAttribute(name: getLinkageAttrName(name: result.name),
2566 attr: LinkageAttr::get(context: builder.getContext(), linkage));
2567 result.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
2568
2569 result.addRegion();
2570}
2571
2572void AliasOp::print(OpAsmPrinter &p) {
2573 printCommonGlobalAndAlias<AliasOp>(p, op: *this);
2574
2575 p.printSymbolName(symbolRef: getSymName());
2576 p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
2577 elidedAttrs: {SymbolTable::getSymbolAttrName(),
2578 getAliasTypeAttrName(), getLinkageAttrName(),
2579 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2580 getVisibility_AttrName()});
2581
2582 // Print the trailing type.
2583 p << " : " << getType() << ' ';
2584 // Print the initializer region.
2585 p.printRegion(blocks&: getInitializerRegion(), /*printEntryBlockArgs=*/false);
2586}
2587
2588// operation ::= `llvm.mlir.alias` linkage? visibility?
2589// (`unnamed_addr` | `local_unnamed_addr`)?
2590// `thread_local`? `@` identifier
2591// `(` attribute? `)`
2592// attribute-list? `:` type region
2593//
2594ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
2595 // Call into common parsing between GlobalOp and AliasOp.
2596 if (parseCommonGlobalAndAlias<AliasOp>(parser, result).failed())
2597 return failure();
2598
2599 StringAttr name;
2600 if (parser.parseSymbolName(result&: name, attrName: getSymNameAttrName(name: result.name),
2601 attrs&: result.attributes))
2602 return failure();
2603
2604 SmallVector<Type, 1> types;
2605 if (parser.parseOptionalAttrDict(result&: result.attributes) ||
2606 parser.parseOptionalColonTypeList(result&: types))
2607 return failure();
2608
2609 if (types.size() > 1)
2610 return parser.emitError(loc: parser.getNameLoc(), message: "expected zero or one type");
2611
2612 Region &initRegion = *result.addRegion();
2613 if (parser.parseRegion(region&: initRegion).failed())
2614 return failure();
2615
2616 result.addAttribute(name: getAliasTypeAttrName(name: result.name),
2617 attr: TypeAttr::get(type: types[0]));
2618 return success();
2619}
2620
2621LogicalResult AliasOp::verify() {
2622 bool validType = isCompatibleOuterType(type: getType())
2623 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2624 LLVMMetadataType, LLVMLabelType>(Val: getType())
2625 : llvm::isa<PointerElementTypeInterface>(Val: getType());
2626 if (!validType)
2627 return emitOpError(
2628 message: "expects type to be a valid element type for an LLVM global alias");
2629
2630 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2631 switch (getLinkage()) {
2632 case Linkage::External:
2633 case Linkage::Internal:
2634 case Linkage::Private:
2635 case Linkage::Weak:
2636 case Linkage::WeakODR:
2637 case Linkage::Linkonce:
2638 case Linkage::LinkonceODR:
2639 case Linkage::AvailableExternally:
2640 break;
2641 default:
2642 return emitOpError()
2643 << "'" << stringifyLinkage(getLinkage())
2644 << "' linkage not supported in aliases, available options: private, "
2645 "internal, linkonce, weak, linkonce_odr, weak_odr, external or "
2646 "available_externally";
2647 }
2648
2649 return success();
2650}
2651
2652LogicalResult AliasOp::verifyRegions() {
2653 Block &b = getInitializerBlock();
2654 auto ret = cast<ReturnOp>(Val: b.getTerminator());
2655 if (ret.getNumOperands() == 0 ||
2656 !isa<LLVM::LLVMPointerType>(Val: ret.getOperand(i: 0).getType()))
2657 return emitOpError(message: "initializer region must always return a pointer");
2658
2659 for (Operation &op : b) {
2660 auto iface = dyn_cast<MemoryEffectOpInterface>(Val&: op);
2661 if (!iface || !iface.hasNoEffect())
2662 return op.emitError()
2663 << "ops with side effects are not allowed in alias initializers";
2664 }
2665
2666 return success();
2667}
2668
2669unsigned AliasOp::getAddrSpace() {
2670 Block &initializer = getInitializerBlock();
2671 auto ret = cast<ReturnOp>(Val: initializer.getTerminator());
2672 auto ptrTy = cast<LLVMPointerType>(Val: ret.getOperand(i: 0).getType());
2673 return ptrTy.getAddressSpace();
2674}
2675
2676//===----------------------------------------------------------------------===//
2677// ShuffleVectorOp
2678//===----------------------------------------------------------------------===//
2679
2680void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2681 Value v2, DenseI32ArrayAttr mask,
2682 ArrayRef<NamedAttribute> attrs) {
2683 auto containerType = v1.getType();
2684 auto vType = LLVM::getVectorType(
2685 elementType: cast<VectorType>(Val&: containerType).getElementType(), numElements: mask.size(),
2686 isScalable: LLVM::isScalableVectorType(vectorType: containerType));
2687 build(odsBuilder&: builder, odsState&: state, res: vType, v1, v2, mask);
2688 state.addAttributes(newAttributes: attrs);
2689}
2690
2691void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2692 Value v2, ArrayRef<int32_t> mask) {
2693 build(builder, state, v1, v2, mask: builder.getDenseI32ArrayAttr(values: mask));
2694}
2695
2696/// Build the result type of a shuffle vector operation.
2697static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2698 Type &resType, DenseI32ArrayAttr mask) {
2699 if (!LLVM::isCompatibleVectorType(type: v1Type))
2700 return parser.emitError(loc: parser.getCurrentLocation(),
2701 message: "expected an LLVM compatible vector type");
2702 resType =
2703 LLVM::getVectorType(elementType: cast<VectorType>(Val&: v1Type).getElementType(),
2704 numElements: mask.size(), isScalable: LLVM::isScalableVectorType(vectorType: v1Type));
2705 return success();
2706}
2707
2708/// Nothing to do when the result type is inferred.
2709static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2710 Type resType, DenseI32ArrayAttr mask) {}
2711
2712LogicalResult ShuffleVectorOp::verify() {
2713 if (LLVM::isScalableVectorType(vectorType: getV1().getType()) &&
2714 llvm::any_of(Range: getMask(), P: [](int32_t v) { return v != 0; }))
2715 return emitOpError(message: "expected a splat operation for scalable vectors");
2716 return success();
2717}
2718
2719//===----------------------------------------------------------------------===//
2720// Implementations for LLVM::LLVMFuncOp.
2721//===----------------------------------------------------------------------===//
2722
2723// Add the entry block to the function.
2724Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
2725 assert(empty() && "function already has an entry block");
2726 OpBuilder::InsertionGuard g(builder);
2727 Block *entry = builder.createBlock(parent: &getBody());
2728
2729 // FIXME: Allow passing in proper locations for the entry arguments.
2730 LLVMFunctionType type = getFunctionType();
2731 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2732 entry->addArgument(type: type.getParamType(i), loc: getLoc());
2733 return entry;
2734}
2735
2736void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2737 StringRef name, Type type, LLVM::Linkage linkage,
2738 bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
2739 ArrayRef<NamedAttribute> attrs,
2740 ArrayRef<DictionaryAttr> argAttrs,
2741 std::optional<uint64_t> functionEntryCount) {
2742 result.addRegion();
2743 result.addAttribute(name: SymbolTable::getSymbolAttrName(),
2744 attr: builder.getStringAttr(bytes: name));
2745 result.addAttribute(name: getFunctionTypeAttrName(name: result.name),
2746 attr: TypeAttr::get(type));
2747 result.addAttribute(name: getLinkageAttrName(name: result.name),
2748 attr: LinkageAttr::get(context: builder.getContext(), linkage));
2749 result.addAttribute(name: getCConvAttrName(name: result.name),
2750 attr: CConvAttr::get(context: builder.getContext(), CallingConv: cconv));
2751 result.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
2752 if (dsoLocal)
2753 result.addAttribute(name: getDsoLocalAttrName(name: result.name),
2754 attr: builder.getUnitAttr());
2755 if (comdat)
2756 result.addAttribute(name: getComdatAttrName(name: result.name), attr: comdat);
2757 if (functionEntryCount)
2758 result.addAttribute(name: getFunctionEntryCountAttrName(name: result.name),
2759 attr: builder.getI64IntegerAttr(value: functionEntryCount.value()));
2760 if (argAttrs.empty())
2761 return;
2762
2763 assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
2764 "expected as many argument attribute lists as arguments");
2765 call_interface_impl::addArgAndResultAttrs(
2766 builder, result, argAttrs, /*resultAttrs=*/{},
2767 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
2768}
2769
2770// Builds an LLVM function type from the given lists of input and output types.
2771// Returns a null type if any of the types provided are non-LLVM types, or if
2772// there is more than one output type.
2773static Type
2774buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2775 ArrayRef<Type> outputs,
2776 function_interface_impl::VariadicFlag variadicFlag) {
2777 Builder &b = parser.getBuilder();
2778 if (outputs.size() > 1) {
2779 parser.emitError(loc, message: "failed to construct function type: expected zero or "
2780 "one function result");
2781 return {};
2782 }
2783
2784 // Convert inputs to LLVM types, exit early on error.
2785 SmallVector<Type, 4> llvmInputs;
2786 for (auto t : inputs) {
2787 if (!isCompatibleType(type: t)) {
2788 parser.emitError(loc, message: "failed to construct function type: expected LLVM "
2789 "type for function arguments");
2790 return {};
2791 }
2792 llvmInputs.push_back(Elt: t);
2793 }
2794
2795 // No output is denoted as "void" in LLVM type system.
2796 Type llvmOutput =
2797 outputs.empty() ? LLVMVoidType::get(ctx: b.getContext()) : outputs.front();
2798 if (!isCompatibleType(type: llvmOutput)) {
2799 parser.emitError(loc, message: "failed to construct function type: expected LLVM "
2800 "type for function results")
2801 << llvmOutput;
2802 return {};
2803 }
2804 return LLVMFunctionType::get(result: llvmOutput, arguments: llvmInputs,
2805 isVarArg: variadicFlag.isVariadic());
2806}
2807
2808// Parses an LLVM function.
2809//
2810// operation ::= `llvm.func` linkage? cconv? function-signature
2811// (`comdat(` symbol-ref-id `)`)?
2812// function-attributes?
2813// function-body
2814//
2815ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
2816 // Default to external linkage if no keyword is provided.
2817 result.addAttribute(name: getLinkageAttrName(name: result.name),
2818 attr: LinkageAttr::get(context: parser.getContext(),
2819 linkage: parseOptionalLLVMKeyword<Linkage>(
2820 parser, defaultValue: LLVM::Linkage::External)));
2821
2822 // Parse optional visibility, default to Default.
2823 result.addAttribute(name: getVisibility_AttrName(name: result.name),
2824 attr: parser.getBuilder().getI64IntegerAttr(
2825 value: parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2826 parser, defaultValue: LLVM::Visibility::Default)));
2827
2828 // Parse optional UnnamedAddr, default to None.
2829 result.addAttribute(name: getUnnamedAddrAttrName(name: result.name),
2830 attr: parser.getBuilder().getI64IntegerAttr(
2831 value: parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2832 parser, defaultValue: LLVM::UnnamedAddr::None)));
2833
2834 // Default to C Calling Convention if no keyword is provided.
2835 result.addAttribute(
2836 name: getCConvAttrName(name: result.name),
2837 attr: CConvAttr::get(context: parser.getContext(),
2838 CallingConv: parseOptionalLLVMKeyword<CConv>(parser, defaultValue: LLVM::CConv::C)));
2839
2840 StringAttr nameAttr;
2841 SmallVector<OpAsmParser::Argument> entryArgs;
2842 SmallVector<DictionaryAttr> resultAttrs;
2843 SmallVector<Type> resultTypes;
2844 bool isVariadic;
2845
2846 auto signatureLocation = parser.getCurrentLocation();
2847 if (parser.parseSymbolName(result&: nameAttr, attrName: SymbolTable::getSymbolAttrName(),
2848 attrs&: result.attributes) ||
2849 function_interface_impl::parseFunctionSignatureWithArguments(
2850 parser, /*allowVariadic=*/true, arguments&: entryArgs, isVariadic, resultTypes,
2851 resultAttrs))
2852 return failure();
2853
2854 SmallVector<Type> argTypes;
2855 for (auto &arg : entryArgs)
2856 argTypes.push_back(Elt: arg.type);
2857 auto type =
2858 buildLLVMFunctionType(parser, loc: signatureLocation, inputs: argTypes, outputs: resultTypes,
2859 variadicFlag: function_interface_impl::VariadicFlag(isVariadic));
2860 if (!type)
2861 return failure();
2862 result.addAttribute(name: getFunctionTypeAttrName(name: result.name),
2863 attr: TypeAttr::get(type));
2864
2865 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "vscale_range"))) {
2866 int64_t minRange, maxRange;
2867 if (parser.parseLParen() || parser.parseInteger(result&: minRange) ||
2868 parser.parseComma() || parser.parseInteger(result&: maxRange) ||
2869 parser.parseRParen())
2870 return failure();
2871 auto intTy = IntegerType::get(context: parser.getContext(), width: 32);
2872 result.addAttribute(
2873 name: getVscaleRangeAttrName(name: result.name),
2874 attr: LLVM::VScaleRangeAttr::get(context: parser.getContext(),
2875 minRange: IntegerAttr::get(type: intTy, value: minRange),
2876 maxRange: IntegerAttr::get(type: intTy, value: maxRange)));
2877 }
2878 // Parse the optional comdat selector.
2879 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "comdat"))) {
2880 SymbolRefAttr comdat;
2881 if (parser.parseLParen() || parser.parseAttribute(result&: comdat) ||
2882 parser.parseRParen())
2883 return failure();
2884
2885 result.addAttribute(name: getComdatAttrName(name: result.name), attr: comdat);
2886 }
2887
2888 if (failed(Result: parser.parseOptionalAttrDictWithKeyword(result&: result.attributes)))
2889 return failure();
2890 call_interface_impl::addArgAndResultAttrs(
2891 builder&: parser.getBuilder(), result, args: entryArgs, resultAttrs,
2892 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
2893
2894 auto *body = result.addRegion();
2895 OptionalParseResult parseResult =
2896 parser.parseOptionalRegion(region&: *body, arguments: entryArgs);
2897 return failure(IsFailure: parseResult.has_value() && failed(Result: *parseResult));
2898}
2899
2900// Print the LLVMFuncOp. Collects argument and result types and passes them to
2901// helper functions. Drops "void" result since it cannot be parsed back. Skips
2902// the external linkage since it is the default value.
2903void LLVMFuncOp::print(OpAsmPrinter &p) {
2904 p << ' ';
2905 if (getLinkage() != LLVM::Linkage::External)
2906 p << stringifyLinkage(getLinkage()) << ' ';
2907 StringRef visibility = stringifyVisibility(getVisibility_());
2908 if (!visibility.empty())
2909 p << visibility << ' ';
2910 if (auto unnamedAddr = getUnnamedAddr()) {
2911 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2912 if (!str.empty())
2913 p << str << ' ';
2914 }
2915 if (getCConv() != LLVM::CConv::C)
2916 p << stringifyCConv(getCConv()) << ' ';
2917
2918 p.printSymbolName(symbolRef: getName());
2919
2920 LLVMFunctionType fnType = getFunctionType();
2921 SmallVector<Type, 8> argTypes;
2922 SmallVector<Type, 1> resTypes;
2923 argTypes.reserve(N: fnType.getNumParams());
2924 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2925 argTypes.push_back(Elt: fnType.getParamType(i));
2926
2927 Type returnType = fnType.getReturnType();
2928 if (!llvm::isa<LLVMVoidType>(Val: returnType))
2929 resTypes.push_back(Elt: returnType);
2930
2931 function_interface_impl::printFunctionSignature(p, op: *this, argTypes,
2932 isVariadic: isVarArg(), resultTypes: resTypes);
2933
2934 // Print vscale range if present
2935 if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
2936 p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
2937 << vscale->getMaxRange().getInt() << ')';
2938
2939 // Print the optional comdat selector.
2940 if (auto comdat = getComdat())
2941 p << " comdat(" << *comdat << ')';
2942
2943 function_interface_impl::printFunctionAttributes(
2944 p, op: *this,
2945 elided: {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2946 getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
2947 getComdatAttrName(), getUnnamedAddrAttrName(),
2948 getVscaleRangeAttrName()});
2949
2950 // Print the body if this is not an external function.
2951 Region &body = getBody();
2952 if (!body.empty()) {
2953 p << ' ';
2954 p.printRegion(blocks&: body, /*printEntryBlockArgs=*/false,
2955 /*printBlockTerminators=*/true);
2956 }
2957}
2958
2959// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2960// - functions don't have 'common' linkage
2961// - external functions have 'external' or 'extern_weak' linkage;
2962// - vararg is (currently) only supported for external functions;
2963LogicalResult LLVMFuncOp::verify() {
2964 if (getLinkage() == LLVM::Linkage::Common)
2965 return emitOpError() << "functions cannot have '"
2966 << stringifyLinkage(LLVM::Linkage::Common)
2967 << "' linkage";
2968
2969 if (failed(Result: verifyComdat(op: *this, attr: getComdat())))
2970 return failure();
2971
2972 if (isExternal()) {
2973 if (getLinkage() != LLVM::Linkage::External &&
2974 getLinkage() != LLVM::Linkage::ExternWeak)
2975 return emitOpError() << "external functions must have '"
2976 << stringifyLinkage(LLVM::Linkage::External)
2977 << "' or '"
2978 << stringifyLinkage(LLVM::Linkage::ExternWeak)
2979 << "' linkage";
2980 return success();
2981 }
2982
2983 // In LLVM IR, these attributes are composed by convention, not by design.
2984 if (isNoInline() && isAlwaysInline())
2985 return emitError(message: "no_inline and always_inline attributes are incompatible");
2986
2987 if (isOptimizeNone() && !isNoInline())
2988 return emitOpError(message: "with optimize_none must also be no_inline");
2989
2990 Type landingpadResultTy;
2991 StringRef diagnosticMessage;
2992 bool isLandingpadTypeConsistent =
2993 !walk(callback: [&](Operation *op) {
2994 const auto checkType = [&](Type type, StringRef errorMessage) {
2995 if (!landingpadResultTy) {
2996 landingpadResultTy = type;
2997 return WalkResult::advance();
2998 }
2999 if (landingpadResultTy != type) {
3000 diagnosticMessage = errorMessage;
3001 return WalkResult::interrupt();
3002 }
3003 return WalkResult::advance();
3004 };
3005 return TypeSwitch<Operation *, WalkResult>(op)
3006 .Case<LandingpadOp>(caseFn: [&](auto landingpad) {
3007 constexpr StringLiteral errorMessage =
3008 "'llvm.landingpad' should have a consistent result type "
3009 "inside a function";
3010 return checkType(landingpad.getType(), errorMessage);
3011 })
3012 .Case<ResumeOp>(caseFn: [&](auto resume) {
3013 constexpr StringLiteral errorMessage =
3014 "'llvm.resume' should have a consistent input type inside a "
3015 "function";
3016 return checkType(resume.getValue().getType(), errorMessage);
3017 })
3018 .Default(defaultFn: [](auto) { return WalkResult::skip(); });
3019 }).wasInterrupted();
3020 if (!isLandingpadTypeConsistent) {
3021 assert(!diagnosticMessage.empty() &&
3022 "Expecting a non-empty diagnostic message");
3023 return emitError(message: diagnosticMessage);
3024 }
3025
3026 if (failed(Result: verifyBlockTags(funcOp: *this)))
3027 return failure();
3028
3029 return success();
3030}
3031
3032/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
3033/// - entry block arguments are of LLVM types.
3034LogicalResult LLVMFuncOp::verifyRegions() {
3035 if (isExternal())
3036 return success();
3037
3038 unsigned numArguments = getFunctionType().getNumParams();
3039 Block &entryBlock = front();
3040 for (unsigned i = 0; i < numArguments; ++i) {
3041 Type argType = entryBlock.getArgument(i).getType();
3042 if (!isCompatibleType(type: argType))
3043 return emitOpError(message: "entry block argument #")
3044 << i << " is not of LLVM type";
3045 }
3046
3047 return success();
3048}
3049
3050Region *LLVMFuncOp::getCallableRegion() {
3051 if (isExternal())
3052 return nullptr;
3053 return &getBody();
3054}
3055
3056//===----------------------------------------------------------------------===//
3057// UndefOp.
3058//===----------------------------------------------------------------------===//
3059
3060/// Fold an undef operation to a dedicated undef attribute.
3061OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
3062 return LLVM::UndefAttr::get(ctx: getContext());
3063}
3064
3065//===----------------------------------------------------------------------===//
3066// PoisonOp.
3067//===----------------------------------------------------------------------===//
3068
3069/// Fold a poison operation to a dedicated poison attribute.
3070OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
3071 return LLVM::PoisonAttr::get(ctx: getContext());
3072}
3073
3074//===----------------------------------------------------------------------===//
3075// ZeroOp.
3076//===----------------------------------------------------------------------===//
3077
3078LogicalResult LLVM::ZeroOp::verify() {
3079 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(Val: getType()))
3080 if (!targetExtType.hasProperty(Prop: LLVM::LLVMTargetExtType::HasZeroInit))
3081 return emitOpError()
3082 << "target extension type does not support zero-initializer";
3083
3084 return success();
3085}
3086
3087/// Fold a zero operation to a builtin zero attribute when possible and fall
3088/// back to a dedicated zero attribute.
3089OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
3090 OpFoldResult result = Builder(getContext()).getZeroAttr(type: getType());
3091 if (result)
3092 return result;
3093 return LLVM::ZeroAttr::get(ctx: getContext());
3094}
3095
3096//===----------------------------------------------------------------------===//
3097// ConstantOp.
3098//===----------------------------------------------------------------------===//
3099
3100/// Compute the total number of elements in the given type, also taking into
3101/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
3102/// Everything else is treated as a scalar.
3103static int64_t getNumElements(Type t) {
3104 if (auto vecType = dyn_cast<VectorType>(Val&: t)) {
3105 assert(!vecType.isScalable() &&
3106 "number of elements of a scalable vector type is unknown");
3107 return vecType.getNumElements() * getNumElements(t: vecType.getElementType());
3108 }
3109 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(Val&: t))
3110 return arrayType.getNumElements() *
3111 getNumElements(t: arrayType.getElementType());
3112 return 1;
3113}
3114
3115/// Check if the given type is a scalable vector type or a vector/array type
3116/// that contains a nested scalable vector type.
3117static bool hasScalableVectorType(Type t) {
3118 if (auto vecType = dyn_cast<VectorType>(Val&: t)) {
3119 if (vecType.isScalable())
3120 return true;
3121 return hasScalableVectorType(t: vecType.getElementType());
3122 }
3123 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(Val&: t))
3124 return hasScalableVectorType(t: arrayType.getElementType());
3125 return false;
3126}
3127
3128/// Verifies the constant array represented by `arrayAttr` matches the provided
3129/// `arrayType`.
3130static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
3131 LLVM::LLVMArrayType arrayType,
3132 ArrayAttr arrayAttr, int dim) {
3133 if (arrayType.getNumElements() != arrayAttr.size())
3134 return op.emitOpError()
3135 << "array attribute size does not match array type size in "
3136 "dimension "
3137 << dim << ": " << arrayAttr.size() << " vs. "
3138 << arrayType.getNumElements();
3139
3140 llvm::DenseSet<Attribute> elementsVerified;
3141
3142 // Recursively verify sub-dimensions for multidimensional arrays.
3143 if (auto subArrayType =
3144 dyn_cast<LLVM::LLVMArrayType>(Val: arrayType.getElementType())) {
3145 for (auto [idx, elementAttr] : llvm::enumerate(First&: arrayAttr))
3146 if (elementsVerified.insert(V: elementAttr).second) {
3147 if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(Val: elementAttr))
3148 continue;
3149 auto subArrayAttr = dyn_cast<ArrayAttr>(Val: elementAttr);
3150 if (!subArrayAttr)
3151 return op.emitOpError()
3152 << "nested attribute for sub-array in dimension " << dim
3153 << " at index " << idx
3154 << " must be a zero, or undef, or array attribute";
3155 if (failed(Result: verifyStructArrayConstant(op, arrayType: subArrayType, arrayAttr: subArrayAttr,
3156 dim: dim + 1)))
3157 return failure();
3158 }
3159 return success();
3160 }
3161
3162 // Forbid usages of ArrayAttr for simple array types that should use
3163 // DenseElementsAttr instead. Note that there would be a use case for such
3164 // array types when one element value is obtained via a ptr-to-int conversion
3165 // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
3166 // user needs this so far, and it seems better to avoid people misusing the
3167 // ArrayAttr for simple types.
3168 auto structType = dyn_cast<LLVM::LLVMStructType>(Val: arrayType.getElementType());
3169 if (!structType)
3170 return op.emitOpError() << "for array with an array attribute must have a "
3171 "struct element type";
3172
3173 // Shallow verification that leaf attributes are appropriate as struct initial
3174 // value.
3175 size_t numStructElements = structType.getBody().size();
3176 for (auto [idx, elementAttr] : llvm::enumerate(First&: arrayAttr)) {
3177 if (elementsVerified.insert(V: elementAttr).second) {
3178 if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(Val: elementAttr))
3179 continue;
3180 auto subArrayAttr = dyn_cast<ArrayAttr>(Val: elementAttr);
3181 if (!subArrayAttr)
3182 return op.emitOpError()
3183 << "nested attribute for struct element at index " << idx
3184 << " must be a zero, or undef, or array attribute";
3185 if (subArrayAttr.size() != numStructElements)
3186 return op.emitOpError()
3187 << "nested array attribute size for struct element at index "
3188 << idx << " must match struct size: " << subArrayAttr.size()
3189 << " vs. " << numStructElements;
3190 }
3191 }
3192
3193 return success();
3194}
3195
3196LogicalResult LLVM::ConstantOp::verify() {
3197 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(Val: getValue())) {
3198 auto arrayType = llvm::dyn_cast<LLVMArrayType>(Val: getType());
3199 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
3200 !arrayType.getElementType().isInteger(width: 8)) {
3201 return emitOpError() << "expected array type of "
3202 << sAttr.getValue().size()
3203 << " i8 elements for the string constant";
3204 }
3205 return success();
3206 }
3207 if (auto structType = dyn_cast<LLVMStructType>(Val: getType())) {
3208 auto arrayAttr = dyn_cast<ArrayAttr>(Val: getValue());
3209 if (!arrayAttr) {
3210 return emitOpError() << "expected array attribute for a struct constant";
3211 }
3212
3213 ArrayRef<Type> elementTypes = structType.getBody();
3214 if (arrayAttr.size() != elementTypes.size()) {
3215 return emitOpError() << "expected array attribute of size "
3216 << elementTypes.size();
3217 }
3218 for (auto elementTy : elementTypes) {
3219 if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(Val: elementTy)) {
3220 return emitOpError() << "expected struct element types to be floating "
3221 "point type or integer type";
3222 }
3223 }
3224
3225 for (size_t i = 0; i < elementTypes.size(); ++i) {
3226 Attribute element = arrayAttr[i];
3227 if (!isa<IntegerAttr, FloatAttr>(Val: element)) {
3228 return emitOpError()
3229 << "expected struct element attribute types to be floating "
3230 "point type or integer type";
3231 }
3232 auto elementType = cast<TypedAttr>(Val&: element).getType();
3233 if (elementType != elementTypes[i]) {
3234 return emitOpError()
3235 << "struct element at index " << i << " is of wrong type";
3236 }
3237 }
3238
3239 return success();
3240 }
3241 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(Val: getType())) {
3242 return emitOpError() << "does not support target extension type.";
3243 }
3244
3245 // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3246 if (auto intAttr = dyn_cast<IntegerAttr>(Val: getValue())) {
3247 if (!llvm::isa<IntegerType>(Val: getType()))
3248 return emitOpError() << "expected integer type";
3249 } else if (auto floatAttr = dyn_cast<FloatAttr>(Val: getValue())) {
3250 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
3251 unsigned floatWidth = APFloat::getSizeInBits(Sem: sem);
3252 if (auto floatTy = dyn_cast<FloatType>(Val: getType())) {
3253 if (floatTy.getWidth() != floatWidth) {
3254 return emitOpError() << "expected float type of width " << floatWidth;
3255 }
3256 }
3257 // See the comment for getLLVMConstant for more details about why 8-bit
3258 // floats can be represented by integers.
3259 if (isa<IntegerType>(Val: getType()) && !getType().isInteger(width: floatWidth)) {
3260 return emitOpError() << "expected integer type of width " << floatWidth;
3261 }
3262 } else if (isa<ElementsAttr>(Val: getValue())) {
3263 if (hasScalableVectorType(t: getType())) {
3264 // The exact number of elements of a scalable vector is unknown, so we
3265 // allow only splat attributes.
3266 auto splatElementsAttr = dyn_cast<SplatElementsAttr>(Val: getValue());
3267 if (!splatElementsAttr)
3268 return emitOpError()
3269 << "scalable vector type requires a splat attribute";
3270 return success();
3271 }
3272 if (!isa<VectorType, LLVM::LLVMArrayType>(Val: getType()))
3273 return emitOpError() << "expected vector or array type";
3274 // The number of elements of the attribute and the type must match.
3275 if (auto elementsAttr = dyn_cast<ElementsAttr>(Val: getValue())) {
3276 int64_t attrNumElements = elementsAttr.getNumElements();
3277 if (getNumElements(t: getType()) != attrNumElements)
3278 return emitOpError()
3279 << "type and attribute have a different number of elements: "
3280 << getNumElements(t: getType()) << " vs. " << attrNumElements;
3281 }
3282 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(Val: getValue())) {
3283 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(Val: getType());
3284 if (!arrayType)
3285 return emitOpError() << "expected array type";
3286 // When the attribute is an ArrayAttr, check that its nesting matches the
3287 // corresponding ArrayType or VectorType nesting.
3288 return verifyStructArrayConstant(op: *this, arrayType, arrayAttr, /*dim=*/0);
3289 } else {
3290 return emitOpError()
3291 << "only supports integer, float, string or elements attributes";
3292 }
3293
3294 return success();
3295}
3296
3297bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
3298 // The value's type must be the same as the provided type.
3299 auto typedAttr = dyn_cast<TypedAttr>(Val&: value);
3300 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
3301 return false;
3302 // The value's type must be an LLVM compatible type.
3303 if (!isCompatibleType(type))
3304 return false;
3305 // TODO: Add support for additional attributes kinds once needed.
3306 return isa<IntegerAttr, FloatAttr, ElementsAttr>(Val: value);
3307}
3308
3309ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
3310 Type type, Location loc) {
3311 if (isBuildableWith(value, type))
3312 return builder.create<LLVM::ConstantOp>(location: loc, args: cast<TypedAttr>(Val&: value));
3313 return nullptr;
3314}
3315
3316// Constant op constant-folds to its value.
3317OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
3318
3319//===----------------------------------------------------------------------===//
3320// AtomicRMWOp
3321//===----------------------------------------------------------------------===//
3322
3323void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3324 AtomicBinOp binOp, Value ptr, Value val,
3325 AtomicOrdering ordering, StringRef syncscope,
3326 unsigned alignment, bool isVolatile) {
3327 build(odsBuilder&: builder, odsState&: state, res: val.getType(), bin_op: binOp, ptr, val, ordering,
3328 syncscope: !syncscope.empty() ? builder.getStringAttr(bytes: syncscope) : nullptr,
3329 alignment: alignment ? builder.getI64IntegerAttr(value: alignment) : nullptr, volatile_: isVolatile,
3330 /*access_groups=*/nullptr,
3331 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3332}
3333
3334LogicalResult AtomicRMWOp::verify() {
3335 auto valType = getVal().getType();
3336 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3337 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax ||
3338 getBinOp() == AtomicBinOp::fminimum ||
3339 getBinOp() == AtomicBinOp::fmaximum) {
3340 if (isCompatibleVectorType(type: valType)) {
3341 if (isScalableVectorType(vectorType: valType))
3342 return emitOpError(message: "expected LLVM IR fixed vector type");
3343 Type elemType = llvm::cast<VectorType>(Val&: valType).getElementType();
3344 if (!isCompatibleFloatingPointType(type: elemType))
3345 return emitOpError(
3346 message: "expected LLVM IR floating point type for vector element");
3347 } else if (!isCompatibleFloatingPointType(type: valType)) {
3348 return emitOpError(message: "expected LLVM IR floating point type");
3349 }
3350 } else if (getBinOp() == AtomicBinOp::xchg) {
3351 DataLayout dataLayout = DataLayout::closest(op: *this);
3352 if (!isTypeCompatibleWithAtomicOp(type: valType, dataLayout))
3353 return emitOpError(message: "unexpected LLVM IR type for 'xchg' bin_op");
3354 } else {
3355 auto intType = llvm::dyn_cast<IntegerType>(Val&: valType);
3356 unsigned intBitWidth = intType ? intType.getWidth() : 0;
3357 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3358 intBitWidth != 64)
3359 return emitOpError(message: "expected LLVM IR integer type");
3360 }
3361
3362 if (static_cast<unsigned>(getOrdering()) <
3363 static_cast<unsigned>(AtomicOrdering::monotonic))
3364 return emitOpError() << "expected at least '"
3365 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3366 << "' ordering";
3367
3368 return success();
3369}
3370
3371//===----------------------------------------------------------------------===//
3372// AtomicCmpXchgOp
3373//===----------------------------------------------------------------------===//
3374
3375/// Returns an LLVM struct type that contains a value type and a boolean type.
3376static LLVMStructType getValAndBoolStructType(Type valType) {
3377 auto boolType = IntegerType::get(context: valType.getContext(), width: 1);
3378 return LLVMStructType::getLiteral(context: valType.getContext(), types: {valType, boolType});
3379}
3380
3381void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3382 Value ptr, Value cmp, Value val,
3383 AtomicOrdering successOrdering,
3384 AtomicOrdering failureOrdering, StringRef syncscope,
3385 unsigned alignment, bool isWeak, bool isVolatile) {
3386 build(odsBuilder&: builder, odsState&: state, res: getValAndBoolStructType(valType: val.getType()), ptr, cmp, val,
3387 success_ordering: successOrdering, failure_ordering: failureOrdering,
3388 syncscope: !syncscope.empty() ? builder.getStringAttr(bytes: syncscope) : nullptr,
3389 alignment: alignment ? builder.getI64IntegerAttr(value: alignment) : nullptr, weak: isWeak,
3390 volatile_: isVolatile, /*access_groups=*/nullptr,
3391 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3392}
3393
3394LogicalResult AtomicCmpXchgOp::verify() {
3395 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(Val: getPtr().getType());
3396 if (!ptrType)
3397 return emitOpError(message: "expected LLVM IR pointer type for operand #0");
3398 auto valType = getVal().getType();
3399 DataLayout dataLayout = DataLayout::closest(op: *this);
3400 if (!isTypeCompatibleWithAtomicOp(type: valType, dataLayout))
3401 return emitOpError(message: "unexpected LLVM IR type");
3402 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3403 getFailureOrdering() < AtomicOrdering::monotonic)
3404 return emitOpError(message: "ordering must be at least 'monotonic'");
3405 if (getFailureOrdering() == AtomicOrdering::release ||
3406 getFailureOrdering() == AtomicOrdering::acq_rel)
3407 return emitOpError(message: "failure ordering cannot be 'release' or 'acq_rel'");
3408 return success();
3409}
3410
3411//===----------------------------------------------------------------------===//
3412// FenceOp
3413//===----------------------------------------------------------------------===//
3414
3415void FenceOp::build(OpBuilder &builder, OperationState &state,
3416 AtomicOrdering ordering, StringRef syncscope) {
3417 build(odsBuilder&: builder, odsState&: state, ordering,
3418 syncscope: syncscope.empty() ? nullptr : builder.getStringAttr(bytes: syncscope));
3419}
3420
3421LogicalResult FenceOp::verify() {
3422 if (getOrdering() == AtomicOrdering::not_atomic ||
3423 getOrdering() == AtomicOrdering::unordered ||
3424 getOrdering() == AtomicOrdering::monotonic)
3425 return emitOpError(message: "can be given only acquire, release, acq_rel, "
3426 "and seq_cst orderings");
3427 return success();
3428}
3429
3430//===----------------------------------------------------------------------===//
3431// Verifier for extension ops
3432//===----------------------------------------------------------------------===//
3433
3434/// Verifies that the given extension operation operates on consistent scalars
3435/// or vectors, and that the target width is larger than the input width.
3436template <class ExtOp>
3437static LogicalResult verifyExtOp(ExtOp op) {
3438 IntegerType inputType, outputType;
3439 if (isCompatibleVectorType(op.getArg().getType())) {
3440 if (!isCompatibleVectorType(op.getResult().getType()))
3441 return op.emitError(
3442 "input type is a vector but output type is an integer");
3443 if (getVectorNumElements(op.getArg().getType()) !=
3444 getVectorNumElements(op.getResult().getType()))
3445 return op.emitError("input and output vectors are of incompatible shape");
3446 // Because this is a CastOp, the element of vectors is guaranteed to be an
3447 // integer.
3448 inputType = cast<IntegerType>(
3449 cast<VectorType>(op.getArg().getType()).getElementType());
3450 outputType = cast<IntegerType>(
3451 cast<VectorType>(op.getResult().getType()).getElementType());
3452 } else {
3453 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3454 // an integer.
3455 inputType = cast<IntegerType>(op.getArg().getType());
3456 outputType = dyn_cast<IntegerType>(op.getResult().getType());
3457 if (!outputType)
3458 return op.emitError(
3459 "input type is an integer but output type is a vector");
3460 }
3461
3462 if (outputType.getWidth() <= inputType.getWidth())
3463 return op.emitError("integer width of the output type is smaller or "
3464 "equal to the integer width of the input type");
3465 return success();
3466}
3467
3468//===----------------------------------------------------------------------===//
3469// ZExtOp
3470//===----------------------------------------------------------------------===//
3471
3472LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(op: *this); }
3473
3474OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3475 auto arg = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getArg());
3476 if (!arg)
3477 return {};
3478
3479 size_t targetSize = cast<IntegerType>(Val: getType()).getWidth();
3480 return IntegerAttr::get(type: getType(), value: arg.getValue().zext(width: targetSize));
3481}
3482
3483//===----------------------------------------------------------------------===//
3484// SExtOp
3485//===----------------------------------------------------------------------===//
3486
3487LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(op: *this); }
3488
3489//===----------------------------------------------------------------------===//
3490// Folder and verifier for LLVM::BitcastOp
3491//===----------------------------------------------------------------------===//
3492
3493/// Folds a cast op that can be chained.
3494template <typename T>
3495static OpFoldResult foldChainableCast(T castOp,
3496 typename T::FoldAdaptor adaptor) {
3497 // cast(x : T0, T0) -> x
3498 if (castOp.getArg().getType() == castOp.getType())
3499 return castOp.getArg();
3500 if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3501 // cast(cast(x : T0, T1), T0) -> x
3502 if (prev.getArg().getType() == castOp.getType())
3503 return prev.getArg();
3504 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
3505 castOp.getArgMutable().set(prev.getArg());
3506 return Value{castOp};
3507 }
3508 return {};
3509}
3510
3511OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
3512 return foldChainableCast(castOp: *this, adaptor);
3513}
3514
3515LogicalResult LLVM::BitcastOp::verify() {
3516 auto resultType = llvm::dyn_cast<LLVMPointerType>(
3517 Val: extractVectorElementType(type: getResult().getType()));
3518 auto sourceType = llvm::dyn_cast<LLVMPointerType>(
3519 Val: extractVectorElementType(type: getArg().getType()));
3520
3521 // If one of the types is a pointer (or vector of pointers), then
3522 // both source and result type have to be pointers.
3523 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3524 return emitOpError(message: "can only cast pointers from and to pointers");
3525
3526 if (!resultType)
3527 return success();
3528
3529 auto isVector = llvm::IsaPred<VectorType>;
3530
3531 // Due to bitcast requiring both operands to be of the same size, it is not
3532 // possible for only one of the two to be a pointer of vectors.
3533 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3534 return emitOpError(message: "cannot cast pointer to vector of pointers");
3535
3536 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3537 return emitOpError(message: "cannot cast vector of pointers to pointer");
3538
3539 // Bitcast cannot cast between pointers of different address spaces.
3540 // 'llvm.addrspacecast' must be used for this purpose instead.
3541 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3542 return emitOpError(message: "cannot cast pointers of different address spaces, "
3543 "use 'llvm.addrspacecast' instead");
3544
3545 return success();
3546}
3547
3548//===----------------------------------------------------------------------===//
3549// Folder for LLVM::AddrSpaceCastOp
3550//===----------------------------------------------------------------------===//
3551
3552OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
3553 return foldChainableCast(castOp: *this, adaptor);
3554}
3555
3556Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3557
3558//===----------------------------------------------------------------------===//
3559// Folder for LLVM::GEPOp
3560//===----------------------------------------------------------------------===//
3561
/// Folds a GEP in one of two ways:
///   1. With a single index that is the constant zero and a result type equal
///      to the base type, the GEP folds to its base pointer.
///   2. Otherwise, dynamic indices whose operand is a constant that fits in
///      kGEPConstantBitWidth signed bits are turned into raw constant indices.
///      The op is updated in place and its own result is returned.
/// Returns a null result when neither applies.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  // View over all indices: raw constant ones plus the folded attribute (if
  // any) of each dynamic index operand.
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(Val: indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  // Every index is re-collected into `gepArgs`; only qualifying ones change.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(First&: indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(Val&: iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(index: iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(N: kGEPConstantBitWidth)) {

      // Keep this index exactly as the op currently holds it (SSA value or
      // raw constant).
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(Val&: existing))
        gepArgs.emplace_back(Args&: val);
      else
        gepArgs.emplace_back(Args: cast<IntegerAttr>(Val&: existing).getInt());

      continue;
    }

    // A dynamic index with a constant, in-range value becomes a raw constant.
    changed = true;
    gepArgs.emplace_back(Args: integer.getInt());
  }
  if (changed) {
    // Rebuild the raw-constant / dynamic split from the updated argument list
    // and rewrite this op's indices in place.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(currType: getElemType(), indices: gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(values: dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    // Returning the op's own result signals an in-place fold.
    return Value{*this};
  }

  return {};
}
3607
3608Value LLVM::GEPOp::getViewSource() { return getBase(); }
3609
3610//===----------------------------------------------------------------------===//
3611// ShlOp
3612//===----------------------------------------------------------------------===//
3613
3614OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3615 auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs());
3616 if (!rhs)
3617 return {};
3618
3619 if (rhs.getValue().getZExtValue() >=
3620 getLhs().getType().getIntOrFloatBitWidth())
3621 return {}; // TODO: Fold into poison.
3622
3623 auto lhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getLhs());
3624 if (!lhs)
3625 return {};
3626
3627 return IntegerAttr::get(type: getType(), value: lhs.getValue().shl(ShiftAmt: rhs.getValue()));
3628}
3629
3630//===----------------------------------------------------------------------===//
3631// OrOp
3632//===----------------------------------------------------------------------===//
3633
3634OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
3635 auto lhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getLhs());
3636 if (!lhs)
3637 return {};
3638
3639 auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs());
3640 if (!rhs)
3641 return {};
3642
3643 return IntegerAttr::get(type: getType(), value: lhs.getValue() | rhs.getValue());
3644}
3645
3646//===----------------------------------------------------------------------===//
3647// CallIntrinsicOp
3648//===----------------------------------------------------------------------===//
3649
3650LogicalResult CallIntrinsicOp::verify() {
3651 if (!getIntrin().starts_with(Prefix: "llvm."))
3652 return emitOpError() << "intrinsic name must start with 'llvm.'";
3653 if (failed(Result: verifyOperandBundles(op&: *this)))
3654 return failure();
3655 return success();
3656}
3657
3658void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3659 mlir::StringAttr intrin, mlir::ValueRange args) {
3660 build(odsBuilder&: builder, odsState&: state, /*resultTypes=*/TypeRange{}, intrin, args,
3661 fastmathFlags: FastmathFlagsAttr{},
3662 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3663 /*res_attrs=*/{});
3664}
3665
3666void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3667 mlir::StringAttr intrin, mlir::ValueRange args,
3668 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3669 build(odsBuilder&: builder, odsState&: state, /*resultTypes=*/TypeRange{}, intrin, args,
3670 fastmathFlags: fastMathFlags,
3671 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3672 /*res_attrs=*/{});
3673}
3674
3675void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3676 mlir::Type resultType, mlir::StringAttr intrin,
3677 mlir::ValueRange args) {
3678 build(odsBuilder&: builder, odsState&: state, results: {resultType}, intrin, args, fastmathFlags: FastmathFlagsAttr{},
3679 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3680 /*res_attrs=*/{});
3681}
3682
3683void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3684 mlir::TypeRange resultTypes,
3685 mlir::StringAttr intrin, mlir::ValueRange args,
3686 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3687 build(odsBuilder&: builder, odsState&: state, resultTypes, intrin, args, fastmathFlags: fastMathFlags,
3688 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3689 /*res_attrs=*/{});
3690}
3691
3692ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
3693 OperationState &result) {
3694 StringAttr intrinAttr;
3695 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3696 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
3697 SmallVector<SmallVector<Type>> opBundleOperandTypes;
3698 ArrayAttr opBundleTags;
3699
3700 // Parse intrinsic name.
3701 if (parser.parseCustomAttributeWithFallback(
3702 result&: intrinAttr, type: parser.getBuilder().getType<NoneType>()))
3703 return failure();
3704 result.addAttribute(name: CallIntrinsicOp::getIntrinAttrName(name: result.name),
3705 attr: intrinAttr);
3706
3707 if (parser.parseLParen())
3708 return failure();
3709
3710 // Parse the function arguments.
3711 if (parser.parseOperandList(result&: operands))
3712 return mlir::failure();
3713
3714 if (parser.parseRParen())
3715 return mlir::failure();
3716
3717 // Handle bundles.
3718 SMLoc opBundlesLoc = parser.getCurrentLocation();
3719 if (std::optional<ParseResult> result = parseOpBundles(
3720 p&: parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
3721 result && failed(Result: *result))
3722 return failure();
3723 if (opBundleTags && !opBundleTags.empty())
3724 result.addAttribute(
3725 name: CallIntrinsicOp::getOpBundleTagsAttrName(name: result.name).getValue(),
3726 attr: opBundleTags);
3727
3728 if (parser.parseOptionalAttrDict(result&: result.attributes))
3729 return mlir::failure();
3730
3731 SmallVector<DictionaryAttr> argAttrs;
3732 SmallVector<DictionaryAttr> resultAttrs;
3733 if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
3734 operands, argAttrs, resultAttrs))
3735 return failure();
3736 call_interface_impl::addArgAndResultAttrs(
3737 builder&: parser.getBuilder(), result, argAttrs, resultAttrs,
3738 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
3739
3740 if (resolveOpBundleOperands(parser, loc: opBundlesLoc, state&: result, opBundleOperands,
3741 opBundleOperandTypes,
3742 opBundleSizesAttrName: getOpBundleSizesAttrName(name: result.name)))
3743 return failure();
3744
3745 int32_t numOpBundleOperands = 0;
3746 for (const auto &operands : opBundleOperands)
3747 numOpBundleOperands += operands.size();
3748
3749 result.addAttribute(
3750 name: CallIntrinsicOp::getOperandSegmentSizeAttr(),
3751 attr: parser.getBuilder().getDenseI32ArrayAttr(
3752 values: {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
3753
3754 return mlir::success();
3755}
3756
3757void CallIntrinsicOp::print(OpAsmPrinter &p) {
3758 p << ' ';
3759 p.printAttributeWithoutType(attr: getIntrinAttr());
3760
3761 OperandRange args = getArgs();
3762 p << "(" << args << ")";
3763
3764 // Operand bundles.
3765 if (!getOpBundleOperands().empty()) {
3766 p << ' ';
3767 printOpBundles(p, op: *this, opBundleOperands: getOpBundleOperands(),
3768 opBundleOperandTypes: getOpBundleOperands().getTypes(), opBundleTags: getOpBundleTagsAttr());
3769 }
3770
3771 p.printOptionalAttrDict(attrs: processFMFAttr(attrs: (*this)->getAttrs()),
3772 elidedAttrs: {getOperandSegmentSizesAttrName(),
3773 getOpBundleSizesAttrName(), getIntrinAttrName(),
3774 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
3775 getResAttrsAttrName()});
3776
3777 p << " : ";
3778
3779 // Reconstruct the MLIR function type from operand and result types.
3780 call_interface_impl::printFunctionSignature(
3781 p, argTypes: args.getTypes(), argAttrs: getArgAttrsAttr(),
3782 /*isVariadic=*/false, resultTypes: getResultTypes(), resultAttrs: getResAttrsAttr());
3783}
3784
3785//===----------------------------------------------------------------------===//
3786// LinkerOptionsOp
3787//===----------------------------------------------------------------------===//
3788
3789LogicalResult LinkerOptionsOp::verify() {
3790 if (mlir::Operation *parentOp = (*this)->getParentOp();
3791 parentOp && !satisfiesLLVMModule(op: parentOp))
3792 return emitOpError(message: "must appear at the module level");
3793 return success();
3794}
3795
3796//===----------------------------------------------------------------------===//
3797// ModuleFlagsOp
3798//===----------------------------------------------------------------------===//
3799
3800LogicalResult ModuleFlagsOp::verify() {
3801 if (Operation *parentOp = (*this)->getParentOp();
3802 parentOp && !satisfiesLLVMModule(op: parentOp))
3803 return emitOpError(message: "must appear at the module level");
3804 for (Attribute flag : getFlags())
3805 if (!isa<ModuleFlagAttr>(Val: flag))
3806 return emitOpError(message: "expected a module flag attribute");
3807 return success();
3808}
3809
3810//===----------------------------------------------------------------------===//
3811// InlineAsmOp
3812//===----------------------------------------------------------------------===//
3813
3814void InlineAsmOp::getEffects(
3815 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3816 &effects) {
3817 if (getHasSideEffects()) {
3818 effects.emplace_back(Args: MemoryEffects::Write::get());
3819 effects.emplace_back(Args: MemoryEffects::Read::get());
3820 }
3821}
3822
3823//===----------------------------------------------------------------------===//
3824// BlockAddressOp
3825//===----------------------------------------------------------------------===//
3826
3827LogicalResult
3828BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3829 Operation *symbol = symbolTable.lookupSymbolIn(symbolTableOp: parentLLVMModule(op: *this),
3830 name: getBlockAddr().getFunction());
3831 auto function = dyn_cast_or_null<LLVMFuncOp>(Val: symbol);
3832
3833 if (!function)
3834 return emitOpError(message: "must reference a function defined by 'llvm.func'");
3835
3836 return success();
3837}
3838
3839LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
3840 return dyn_cast_or_null<LLVMFuncOp>(Val: symbolTable.lookupSymbolIn(
3841 symbolTableOp: parentLLVMModule(op: *this), name: getBlockAddr().getFunction()));
3842}
3843
3844BlockTagOp BlockAddressOp::getBlockTagOp() {
3845 auto funcOp = dyn_cast<LLVMFuncOp>(Val: mlir::SymbolTable::lookupNearestSymbolFrom(
3846 from: parentLLVMModule(op: *this), symbol: getBlockAddr().getFunction()));
3847 if (!funcOp)
3848 return nullptr;
3849
3850 BlockTagOp blockTagOp = nullptr;
3851 funcOp.walk(callback: [&](LLVM::BlockTagOp labelOp) {
3852 if (labelOp.getTag() == getBlockAddr().getTag()) {
3853 blockTagOp = labelOp;
3854 return WalkResult::interrupt();
3855 }
3856 return WalkResult::advance();
3857 });
3858 return blockTagOp;
3859}
3860
3861LogicalResult BlockAddressOp::verify() {
3862 if (!getBlockTagOp())
3863 return emitOpError(
3864 message: "expects an existing block label target in the referenced function");
3865
3866 return success();
3867}
3868
3869/// Fold a blockaddress operation to a dedicated blockaddress
3870/// attribute.
3871OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
3872
3873//===----------------------------------------------------------------------===//
3874// LLVM::IndirectBrOp
3875//===----------------------------------------------------------------------===//
3876
3877SuccessorOperands IndirectBrOp::getSuccessorOperands(unsigned index) {
3878 assert(index < getNumSuccessors() && "invalid successor index");
3879 return SuccessorOperands(getSuccOperandsMutable()[index]);
3880}
3881
3882void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3883 Value addr, ArrayRef<ValueRange> succOperands,
3884 BlockRange successors) {
3885 odsState.addOperands(newOperands: addr);
3886 for (ValueRange range : succOperands)
3887 odsState.addOperands(newOperands: range);
3888 SmallVector<int32_t> rangeSegments;
3889 for (ValueRange range : succOperands)
3890 rangeSegments.push_back(Elt: range.size());
3891 odsState.getOrAddProperties<Properties>().indbr_operand_segments =
3892 odsBuilder.getDenseI32ArrayAttr(values: rangeSegments);
3893 odsState.addSuccessors(newSuccessors: successors);
3894}
3895
3896static ParseResult parseIndirectBrOpSucessors(
3897 OpAsmParser &parser, Type &flagType,
3898 SmallVectorImpl<Block *> &succOperandBlocks,
3899 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
3900 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
3901 if (failed(Result: parser.parseCommaSeparatedList(
3902 delimiter: OpAsmParser::Delimiter::Square,
3903 parseElementFn: [&]() {
3904 Block *destination = nullptr;
3905 SmallVector<OpAsmParser::UnresolvedOperand> operands;
3906 SmallVector<Type> operandTypes;
3907
3908 if (parser.parseSuccessor(dest&: destination).failed())
3909 return failure();
3910
3911 if (succeeded(Result: parser.parseOptionalLParen())) {
3912 if (failed(Result: parser.parseOperandList(
3913 result&: operands, delimiter: OpAsmParser::Delimiter::None)) ||
3914 failed(Result: parser.parseColonTypeList(result&: operandTypes)) ||
3915 failed(Result: parser.parseRParen()))
3916 return failure();
3917 }
3918 succOperandBlocks.push_back(Elt: destination);
3919 succOperands.emplace_back(Args&: operands);
3920 succOperandsTypes.emplace_back(Args&: operandTypes);
3921 return success();
3922 },
3923 contextMessage: "successor blocks")))
3924 return failure();
3925 return success();
3926}
3927
3928static void
3929printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
3930 SuccessorRange succs, OperandRangeRange succOperands,
3931 const TypeRangeRange &succOperandsTypes) {
3932 p << "[";
3933 llvm::interleave(
3934 c: llvm::zip(t&: succs, u&: succOperands),
3935 each_fn: [&](auto i) {
3936 p.printNewline();
3937 p.printSuccessorAndUseList(successor: std::get<0>(i), succOperands: std::get<1>(i));
3938 },
3939 between_fn: [&] { p << ','; });
3940 if (!succOperands.empty())
3941 p.printNewline();
3942 p << "]";
3943}
3944
3945//===----------------------------------------------------------------------===//
3946// AssumeOp (intrinsic)
3947//===----------------------------------------------------------------------===//
3948
3949void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3950 mlir::Value cond) {
3951 return build(odsBuilder&: builder, odsState&: state, cond, /*op_bundle_operands=*/{},
3952 /*op_bundle_tags=*/ArrayAttr{});
3953}
3954
3955void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3956 Value cond,
3957 ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) {
3958 SmallVector<ValueRange> opBundleOperands;
3959 SmallVector<Attribute> opBundleTags;
3960 opBundleOperands.reserve(N: opBundles.size());
3961 opBundleTags.reserve(N: opBundles.size());
3962
3963 for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) {
3964 opBundleOperands.emplace_back(Args: bundle.inputs());
3965 opBundleTags.push_back(
3966 Elt: StringAttr::get(context: builder.getContext(), bytes: bundle.getTag()));
3967 }
3968
3969 auto opBundleTagsAttr = ArrayAttr::get(context: builder.getContext(), value: opBundleTags);
3970 return build(odsBuilder&: builder, odsState&: state, cond, op_bundle_operands: opBundleOperands, op_bundle_tags: opBundleTagsAttr);
3971}
3972
3973void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3974 Value cond, llvm::StringRef tag, ValueRange args) {
3975 llvm::OperandBundleDefT<Value> opBundle(
3976 tag.str(), SmallVector<Value>(args.begin(), args.end()));
3977 return build(builder, state, cond, opBundles: opBundle);
3978}
3979
3980void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3981 Value cond, AssumeAlignTag, Value ptr, Value align) {
3982 return build(builder, state, cond, tag: "align", args: ValueRange{ptr, align});
3983}
3984
3985void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3986 Value cond, AssumeSeparateStorageTag, Value ptr1,
3987 Value ptr2) {
3988 return build(builder, state, cond, tag: "separate_storage",
3989 args: ValueRange{ptr1, ptr2});
3990}
3991
3992LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(op&: *this); }
3993
3994//===----------------------------------------------------------------------===//
3995// masked_gather (intrinsic)
3996//===----------------------------------------------------------------------===//
3997
3998LogicalResult LLVM::masked_gather::verify() {
3999 auto ptrsVectorType = getPtrs().getType();
4000 Type expectedPtrsVectorType =
4001 LLVM::getVectorType(elementType: extractVectorElementType(type: ptrsVectorType),
4002 numElements: LLVM::getVectorNumElements(type: getRes().getType()));
4003 // Vector of pointers type should match result vector type, other than the
4004 // element type.
4005 if (ptrsVectorType != expectedPtrsVectorType)
4006 return emitOpError(message: "expected operand #1 type to be ")
4007 << expectedPtrsVectorType;
4008 return success();
4009}
4010
4011//===----------------------------------------------------------------------===//
4012// masked_scatter (intrinsic)
4013//===----------------------------------------------------------------------===//
4014
4015LogicalResult LLVM::masked_scatter::verify() {
4016 auto ptrsVectorType = getPtrs().getType();
4017 Type expectedPtrsVectorType =
4018 LLVM::getVectorType(elementType: extractVectorElementType(type: ptrsVectorType),
4019 numElements: LLVM::getVectorNumElements(type: getValue().getType()));
4020 // Vector of pointers type should match value vector type, other than the
4021 // element type.
4022 if (ptrsVectorType != expectedPtrsVectorType)
4023 return emitOpError(message: "expected operand #2 type to be ")
4024 << expectedPtrsVectorType;
4025 return success();
4026}
4027
4028//===----------------------------------------------------------------------===//
4029// InlineAsmOp
4030//===----------------------------------------------------------------------===//
4031
4032LogicalResult InlineAsmOp::verify() {
4033 if (!getTailCallKindAttr())
4034 return success();
4035
4036 if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4037 return emitOpError(
4038 message: "tail call kind 'musttail' is not supported by this operation");
4039
4040 return success();
4041}
4042
4043//===----------------------------------------------------------------------===//
4044// LLVMDialect initialization, type parsing, and registration.
4045//===----------------------------------------------------------------------===//
4046
4047void LLVMDialect::initialize() {
4048 registerAttributes();
4049
4050 // clang-format off
4051 addTypes<LLVMVoidType,
4052 LLVMTokenType,
4053 LLVMLabelType,
4054 LLVMMetadataType>();
4055 // clang-format on
4056 registerTypes();
4057
4058 addOperations<
4059#define GET_OP_LIST
4060#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4061 ,
4062#define GET_OP_LIST
4063#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4064 >();
4065
4066 // Support unknown operations because not all LLVM operations are registered.
4067 allowUnknownOperations();
4068 declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
4069}
4070
4071#define GET_OP_CLASSES
4072#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4073
4074#define GET_OP_CLASSES
4075#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4076
4077LogicalResult LLVMDialect::verifyDataLayoutString(
4078 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
4079 llvm::Expected<llvm::DataLayout> maybeDataLayout =
4080 llvm::DataLayout::parse(LayoutString: descr);
4081 if (maybeDataLayout)
4082 return success();
4083
4084 std::string message;
4085 llvm::raw_string_ostream messageStream(message);
4086 llvm::logAllUnhandledErrors(E: maybeDataLayout.takeError(), OS&: messageStream);
4087 reportError("invalid data layout descriptor: " + message);
4088 return failure();
4089}
4090
4091/// Verify LLVM dialect attributes.
4092LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
4093 NamedAttribute attr) {
4094 // If the data layout attribute is present, it must use the LLVM data layout
4095 // syntax. Try parsing it and report errors in case of failure. Users of this
4096 // attribute may assume it is well-formed and can pass it to the (asserting)
4097 // llvm::DataLayout constructor.
4098 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
4099 return success();
4100 if (auto stringAttr = llvm::dyn_cast<StringAttr>(Val: attr.getValue()))
4101 return verifyDataLayoutString(
4102 descr: stringAttr.getValue(),
4103 reportError: [op](const Twine &message) { op->emitOpError() << message.str(); });
4104
4105 return op->emitOpError() << "expected '"
4106 << LLVM::LLVMDialect::getDataLayoutAttrName()
4107 << "' to be a string attributes";
4108}
4109
4110LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
4111 Type paramType,
4112 NamedAttribute paramAttr) {
4113 // LLVM attribute may be attached to a result of operation that has not been
4114 // converted to LLVM dialect yet, so the result may have a type with unknown
4115 // representation in LLVM dialect type space. In this case we cannot verify
4116 // whether the attribute may be
4117 bool verifyValueType = isCompatibleType(paramType);
4118 StringAttr name = paramAttr.getName();
4119
4120 auto checkUnitAttrType = [&]() -> LogicalResult {
4121 if (!llvm::isa<UnitAttr>(Val: paramAttr.getValue()))
4122 return op->emitError() << name << " should be a unit attribute";
4123 return success();
4124 };
4125 auto checkTypeAttrType = [&]() -> LogicalResult {
4126 if (!llvm::isa<TypeAttr>(Val: paramAttr.getValue()))
4127 return op->emitError() << name << " should be a type attribute";
4128 return success();
4129 };
4130 auto checkIntegerAttrType = [&]() -> LogicalResult {
4131 if (!llvm::isa<IntegerAttr>(Val: paramAttr.getValue()))
4132 return op->emitError() << name << " should be an integer attribute";
4133 return success();
4134 };
4135 auto checkPointerType = [&]() -> LogicalResult {
4136 if (!llvm::isa<LLVMPointerType>(Val: paramType))
4137 return op->emitError()
4138 << name << " attribute attached to non-pointer LLVM type";
4139 return success();
4140 };
4141 auto checkIntegerType = [&]() -> LogicalResult {
4142 if (!llvm::isa<IntegerType>(Val: paramType))
4143 return op->emitError()
4144 << name << " attribute attached to non-integer LLVM type";
4145 return success();
4146 };
4147 auto checkPointerTypeMatches = [&]() -> LogicalResult {
4148 if (failed(Result: checkPointerType()))
4149 return failure();
4150
4151 return success();
4152 };
4153
4154 // Check a unit attribute that is attached to a pointer value.
4155 if (name == LLVMDialect::getNoAliasAttrName() ||
4156 name == LLVMDialect::getReadonlyAttrName() ||
4157 name == LLVMDialect::getReadnoneAttrName() ||
4158 name == LLVMDialect::getWriteOnlyAttrName() ||
4159 name == LLVMDialect::getNestAttrName() ||
4160 name == LLVMDialect::getNoCaptureAttrName() ||
4161 name == LLVMDialect::getNoFreeAttrName() ||
4162 name == LLVMDialect::getNonNullAttrName()) {
4163 if (failed(Result: checkUnitAttrType()))
4164 return failure();
4165 if (verifyValueType && failed(Result: checkPointerType()))
4166 return failure();
4167 return success();
4168 }
4169
4170 // Check a type attribute that is attached to a pointer value.
4171 if (name == LLVMDialect::getStructRetAttrName() ||
4172 name == LLVMDialect::getByValAttrName() ||
4173 name == LLVMDialect::getByRefAttrName() ||
4174 name == LLVMDialect::getElementTypeAttrName() ||
4175 name == LLVMDialect::getInAllocaAttrName() ||
4176 name == LLVMDialect::getPreallocatedAttrName()) {
4177 if (failed(Result: checkTypeAttrType()))
4178 return failure();
4179 if (verifyValueType && failed(Result: checkPointerTypeMatches()))
4180 return failure();
4181 return success();
4182 }
4183
4184 // Check a unit attribute that is attached to an integer value.
4185 if (name == LLVMDialect::getSExtAttrName() ||
4186 name == LLVMDialect::getZExtAttrName()) {
4187 if (failed(Result: checkUnitAttrType()))
4188 return failure();
4189 if (verifyValueType && failed(Result: checkIntegerType()))
4190 return failure();
4191 return success();
4192 }
4193
4194 // Check an integer attribute that is attached to a pointer value.
4195 if (name == LLVMDialect::getAlignAttrName() ||
4196 name == LLVMDialect::getDereferenceableAttrName() ||
4197 name == LLVMDialect::getDereferenceableOrNullAttrName()) {
4198 if (failed(Result: checkIntegerAttrType()))
4199 return failure();
4200 if (verifyValueType && failed(Result: checkPointerType()))
4201 return failure();
4202 return success();
4203 }
4204
4205 // Check an integer attribute that is attached to a pointer value.
4206 if (name == LLVMDialect::getStackAlignmentAttrName()) {
4207 if (failed(Result: checkIntegerAttrType()))
4208 return failure();
4209 return success();
4210 }
4211
4212 // Check a unit attribute that can be attached to arbitrary types.
4213 if (name == LLVMDialect::getNoUndefAttrName() ||
4214 name == LLVMDialect::getInRegAttrName() ||
4215 name == LLVMDialect::getReturnedAttrName())
4216 return checkUnitAttrType();
4217
4218 return success();
4219}
4220
4221/// Verify LLVMIR function argument attributes.
4222LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
4223 unsigned regionIdx,
4224 unsigned argIdx,
4225 NamedAttribute argAttr) {
4226 auto funcOp = dyn_cast<FunctionOpInterface>(Val: op);
4227 if (!funcOp)
4228 return success();
4229 Type argType = funcOp.getArgumentTypes()[argIdx];
4230
4231 return verifyParameterAttribute(op, paramType: argType, paramAttr: argAttr);
4232}
4233
4234LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
4235 unsigned regionIdx,
4236 unsigned resIdx,
4237 NamedAttribute resAttr) {
4238 auto funcOp = dyn_cast<FunctionOpInterface>(Val: op);
4239 if (!funcOp)
4240 return success();
4241 Type resType = funcOp.getResultTypes()[resIdx];
4242
4243 // Check to see if this function has a void return with a result attribute
4244 // to it. It isn't clear what semantics we would assign to that.
4245 if (llvm::isa<LLVMVoidType>(Val: resType))
4246 return op->emitError() << "cannot attach result attributes to functions "
4247 "with a void return";
4248
4249 // Check to see if this attribute is allowed as a result attribute. Only
4250 // explicitly forbidden LLVM attributes will cause an error.
4251 auto name = resAttr.getName();
4252 if (name == LLVMDialect::getAllocAlignAttrName() ||
4253 name == LLVMDialect::getAllocatedPointerAttrName() ||
4254 name == LLVMDialect::getByValAttrName() ||
4255 name == LLVMDialect::getByRefAttrName() ||
4256 name == LLVMDialect::getInAllocaAttrName() ||
4257 name == LLVMDialect::getNestAttrName() ||
4258 name == LLVMDialect::getNoCaptureAttrName() ||
4259 name == LLVMDialect::getNoFreeAttrName() ||
4260 name == LLVMDialect::getPreallocatedAttrName() ||
4261 name == LLVMDialect::getReadnoneAttrName() ||
4262 name == LLVMDialect::getReadonlyAttrName() ||
4263 name == LLVMDialect::getReturnedAttrName() ||
4264 name == LLVMDialect::getStackAlignmentAttrName() ||
4265 name == LLVMDialect::getStructRetAttrName() ||
4266 name == LLVMDialect::getWriteOnlyAttrName())
4267 return op->emitError() << name << " is not a valid result attribute";
4268 return verifyParameterAttribute(op, paramType: resType, paramAttr: resAttr);
4269}
4270
4271Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
4272 Type type, Location loc) {
4273 // If this was folded from an operation other than llvm.mlir.constant, it
4274 // should be materialized as such. Note that an llvm.mlir.zero may fold into
4275 // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
4276 if (auto symbol = dyn_cast<FlatSymbolRefAttr>(Val&: value))
4277 if (isa<LLVM::LLVMPointerType>(Val: type))
4278 return builder.create<LLVM::AddressOfOp>(location: loc, args&: type, args&: symbol);
4279 if (isa<LLVM::UndefAttr>(Val: value))
4280 return builder.create<LLVM::UndefOp>(location: loc, args&: type);
4281 if (isa<LLVM::PoisonAttr>(Val: value))
4282 return builder.create<LLVM::PoisonOp>(location: loc, args&: type);
4283 if (isa<LLVM::ZeroAttr>(Val: value))
4284 return builder.create<LLVM::ZeroOp>(location: loc, args&: type);
4285 // Otherwise try materializing it as a regular llvm.mlir.constant op.
4286 return LLVM::ConstantOp::materialize(builder, value, type, loc);
4287}
4288
4289//===----------------------------------------------------------------------===//
4290// Utility functions.
4291//===----------------------------------------------------------------------===//
4292
4293Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
4294 StringRef name, StringRef value,
4295 LLVM::Linkage linkage) {
4296 assert(builder.getInsertionBlock() &&
4297 builder.getInsertionBlock()->getParentOp() &&
4298 "expected builder to point to a block constrained in an op");
4299 auto module =
4300 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
4301 assert(module && "builder points to an op outside of a module");
4302
4303 // Create the global at the entry of the module.
4304 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
4305 MLIRContext *ctx = builder.getContext();
4306 auto type = LLVM::LLVMArrayType::get(elementType: IntegerType::get(context: ctx, width: 8), numElements: value.size());
4307 auto global = moduleBuilder.create<LLVM::GlobalOp>(
4308 location: loc, args&: type, /*isConstant=*/args: true, args&: linkage, args&: name,
4309 args: builder.getStringAttr(bytes: value), /*alignment=*/args: 0);
4310
4311 LLVMPointerType ptrType = LLVMPointerType::get(context: ctx);
4312 // Get the pointer to the first character in the global string.
4313 Value globalPtr =
4314 builder.create<LLVM::AddressOfOp>(location: loc, args&: ptrType, args: global.getSymNameAttr());
4315 return builder.create<LLVM::GEPOp>(location: loc, args&: ptrType, args&: type, args&: globalPtr,
4316 args: ArrayRef<GEPArg>{0, 0});
4317}
4318
4319bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
4320 return op->hasTrait<OpTrait::SymbolTable>() &&
4321 op->hasTrait<OpTrait::IsIsolatedFromAbove>();
4322}
4323

source code of mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp