1//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the LLVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "TypeDetail.h"
16#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
17#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
18#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27#include "mlir/Transforms/InliningUtils.h"
28
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"

#include <numeric>
#include <optional>
#include <type_traits>
43
44using namespace mlir;
45using namespace mlir::LLVM;
46using mlir::LLVM::cconv::getMaxEnumValForCConv;
47using mlir::LLVM::linkage::getMaxEnumValForLinkage;
48using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
49
50#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
51
52//===----------------------------------------------------------------------===//
53// Attribute Helpers
54//===----------------------------------------------------------------------===//
55
/// Name of the attribute that records the element type of `llvm.alloca`
/// (attached by AllocaOp::parse and elided by AllocaOp::print).
static constexpr char kElemTypeAttrName[] = "elem_type";
57
58static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
59 SmallVector<NamedAttribute, 8> filteredAttrs(
60 llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) {
61 if (attr.getName() == "fastmathFlags") {
62 auto defAttr =
63 FastmathFlagsAttr::get(attr.getValue().getContext(), {});
64 return defAttr != attr.getValue();
65 }
66 return true;
67 }));
68 return filteredAttrs;
69}
70
71/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
72/// fully defined llvm.func.
73static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
74 Operation *op,
75 SymbolTableCollection &symbolTable) {
76 StringRef name = symbol.getValue();
77 auto func =
78 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
79 if (!func)
80 return op->emitOpError(message: "'")
81 << name << "' does not reference a valid LLVM function";
82 if (func.isExternal())
83 return op->emitOpError(message: "'") << name << "' does not have a definition";
84 return success();
85}
86
87/// Returns a boolean type that has the same shape as `type`. It supports both
88/// fixed size vectors as well as scalable vectors.
89static Type getI1SameShape(Type type) {
90 Type i1Type = IntegerType::get(type.getContext(), 1);
91 if (LLVM::isCompatibleVectorType(type))
92 return LLVM::getVectorType(elementType: i1Type, numElements: LLVM::getVectorNumElements(type));
93 return i1Type;
94}
95
96// Parses one of the keywords provided in the list `keywords` and returns the
97// position of the parsed keyword in the list. If none of the keywords from the
98// list is parsed, returns -1.
99static int parseOptionalKeywordAlternative(OpAsmParser &parser,
100 ArrayRef<StringRef> keywords) {
101 for (const auto &en : llvm::enumerate(First&: keywords)) {
102 if (succeeded(Result: parser.parseOptionalKeyword(keyword: en.value())))
103 return en.index();
104 }
105 return -1;
106}
107
108namespace {
109template <typename Ty>
110struct EnumTraits {};
111
112#define REGISTER_ENUM_TYPE(Ty) \
113 template <> \
114 struct EnumTraits<Ty> { \
115 static StringRef stringify(Ty value) { return stringify##Ty(value); } \
116 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
117 }
118
119REGISTER_ENUM_TYPE(Linkage);
120REGISTER_ENUM_TYPE(UnnamedAddr);
121REGISTER_ENUM_TYPE(CConv);
122REGISTER_ENUM_TYPE(TailCallKind);
123REGISTER_ENUM_TYPE(Visibility);
124} // namespace
125
126/// Parse an enum from the keyword, or default to the provided default value.
127/// The return type is the enum type by default, unless overridden with the
128/// second template argument.
129template <typename EnumTy, typename RetTy = EnumTy>
130static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
131 OperationState &result,
132 EnumTy defaultValue) {
133 SmallVector<StringRef, 10> names;
134 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
135 names.push_back(Elt: EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
136
137 int index = parseOptionalKeywordAlternative(parser, keywords: names);
138 if (index == -1)
139 return static_cast<RetTy>(defaultValue);
140 return static_cast<RetTy>(index);
141}
142
143//===----------------------------------------------------------------------===//
144// Operand bundle helpers.
145//===----------------------------------------------------------------------===//
146
147static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
148 TypeRange operandTypes, StringRef tag) {
149 p.printString(string: tag);
150 p << "(";
151
152 if (!operands.empty()) {
153 p.printOperands(container: operands);
154 p << " : ";
155 llvm::interleaveComma(c: operandTypes, os&: p);
156 }
157
158 p << ")";
159}
160
161static void printOpBundles(OpAsmPrinter &p, Operation *op,
162 OperandRangeRange opBundleOperands,
163 TypeRangeRange opBundleOperandTypes,
164 std::optional<ArrayAttr> opBundleTags) {
165 if (opBundleOperands.empty())
166 return;
167 assert(opBundleTags && "expect operand bundle tags");
168
169 p << "[";
170 llvm::interleaveComma(
171 llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
172 [&p](auto bundle) {
173 auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
174 printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
175 bundleTag);
176 });
177 p << "]";
178}
179
180static ParseResult parseOneOpBundle(
181 OpAsmParser &p,
182 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
183 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
184 SmallVector<Attribute> &opBundleTags) {
185 SMLoc currentParserLoc = p.getCurrentLocation();
186 SmallVector<OpAsmParser::UnresolvedOperand> operands;
187 SmallVector<Type> types;
188 std::string tag;
189
190 if (p.parseString(string: &tag))
191 return p.emitError(loc: currentParserLoc, message: "expect operand bundle tag");
192
193 if (p.parseLParen())
194 return failure();
195
196 if (p.parseOptionalRParen()) {
197 if (p.parseOperandList(result&: operands) || p.parseColon() ||
198 p.parseTypeList(result&: types) || p.parseRParen())
199 return failure();
200 }
201
202 opBundleOperands.push_back(Elt: std::move(operands));
203 opBundleOperandTypes.push_back(Elt: std::move(types));
204 opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
205
206 return success();
207}
208
209static std::optional<ParseResult> parseOpBundles(
210 OpAsmParser &p,
211 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands,
212 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
213 ArrayAttr &opBundleTags) {
214 if (p.parseOptionalLSquare())
215 return std::nullopt;
216
217 if (succeeded(Result: p.parseOptionalRSquare()))
218 return success();
219
220 SmallVector<Attribute> opBundleTagAttrs;
221 auto bundleParser = [&] {
222 return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
223 opBundleTags&: opBundleTagAttrs);
224 };
225 if (p.parseCommaSeparatedList(parseElementFn: bundleParser))
226 return failure();
227
228 if (p.parseRSquare())
229 return failure();
230
231 opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
232
233 return success();
234}
235
236//===----------------------------------------------------------------------===//
// Printing, parsing and folding for LLVM::ICmpOp and LLVM::FCmpOp.
238//===----------------------------------------------------------------------===//
239
240void ICmpOp::print(OpAsmPrinter &p) {
241 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
242 << ", " << getOperand(1);
243 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
244 p << " : " << getLhs().getType();
245}
246
247void FCmpOp::print(OpAsmPrinter &p) {
248 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
249 << ", " << getOperand(1);
250 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
251 p << " : " << getLhs().getType();
252}
253
254// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
255// attribute-dict? `:` type
256// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
257// attribute-dict? `:` type
258template <typename CmpPredicateType>
259static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
260 StringAttr predicateAttr;
261 OpAsmParser::UnresolvedOperand lhs, rhs;
262 Type type;
263 SMLoc predicateLoc, trailingTypeLoc;
264 if (parser.getCurrentLocation(loc: &predicateLoc) ||
265 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
266 parser.parseOperand(result&: lhs) || parser.parseComma() ||
267 parser.parseOperand(result&: rhs) ||
268 parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() ||
269 parser.getCurrentLocation(loc: &trailingTypeLoc) || parser.parseType(result&: type) ||
270 parser.resolveOperand(operand: lhs, type, result&: result.operands) ||
271 parser.resolveOperand(operand: rhs, type, result&: result.operands))
272 return failure();
273
274 // Replace the string attribute `predicate` with an integer attribute.
275 int64_t predicateValue = 0;
276 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
277 std::optional<ICmpPredicate> predicate =
278 symbolizeICmpPredicate(predicateAttr.getValue());
279 if (!predicate)
280 return parser.emitError(loc: predicateLoc)
281 << "'" << predicateAttr.getValue()
282 << "' is an incorrect value of the 'predicate' attribute";
283 predicateValue = static_cast<int64_t>(*predicate);
284 } else {
285 std::optional<FCmpPredicate> predicate =
286 symbolizeFCmpPredicate(predicateAttr.getValue());
287 if (!predicate)
288 return parser.emitError(loc: predicateLoc)
289 << "'" << predicateAttr.getValue()
290 << "' is an incorrect value of the 'predicate' attribute";
291 predicateValue = static_cast<int64_t>(*predicate);
292 }
293
294 result.attributes.set("predicate",
295 parser.getBuilder().getI64IntegerAttr(predicateValue));
296
297 // The result type is either i1 or a vector type <? x i1> if the inputs are
298 // vectors.
299 if (!isCompatibleType(type))
300 return parser.emitError(loc: trailingTypeLoc,
301 message: "expected LLVM dialect-compatible type");
302 result.addTypes(newTypes: getI1SameShape(type));
303 return success();
304}
305
306ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
307 return parseCmpOp<ICmpPredicate>(parser, result);
308}
309
310ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
311 return parseCmpOp<FCmpPredicate>(parser, result);
312}
313
314/// Returns a scalar or vector boolean attribute of the given type.
315static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
316 auto boolAttr = BoolAttr::get(context: ctx, value);
317 ShapedType shapedType = dyn_cast<ShapedType>(type);
318 if (!shapedType)
319 return boolAttr;
320 return DenseElementsAttr::get(shapedType, boolAttr);
321}
322
323OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
324 if (getPredicate() != ICmpPredicate::eq &&
325 getPredicate() != ICmpPredicate::ne)
326 return {};
327
328 // cmpi(eq/ne, x, x) -> true/false
329 if (getLhs() == getRhs())
330 return getBoolAttribute(getType(), getContext(),
331 getPredicate() == ICmpPredicate::eq);
332
333 // cmpi(eq/ne, alloca, null) -> false/true
334 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
335 return getBoolAttribute(getType(), getContext(),
336 getPredicate() == ICmpPredicate::ne);
337
338 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
339 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
340 Value lhs = getLhs();
341 Value rhs = getRhs();
342 getLhsMutable().assign(rhs);
343 getRhsMutable().assign(lhs);
344 return getResult();
345 }
346
347 return {};
348}
349
350//===----------------------------------------------------------------------===//
351// Printing, parsing and verification for LLVM::AllocaOp.
352//===----------------------------------------------------------------------===//
353
354void AllocaOp::print(OpAsmPrinter &p) {
355 auto funcTy =
356 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
357
358 if (getInalloca())
359 p << " inalloca";
360
361 p << ' ' << getArraySize() << " x " << getElemType();
362 if (getAlignment() && *getAlignment() != 0)
363 p.printOptionalAttrDict((*this)->getAttrs(),
364 {kElemTypeAttrName, getInallocaAttrName()});
365 else
366 p.printOptionalAttrDict(
367 (*this)->getAttrs(),
368 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
369 p << " : " << funcTy;
370}
371
372// <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
373// attribute-dict? `:` type `,` type
374ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
375 OpAsmParser::UnresolvedOperand arraySize;
376 Type type, elemType;
377 SMLoc trailingTypeLoc;
378
379 if (succeeded(parser.parseOptionalKeyword("inalloca")))
380 result.addAttribute(getInallocaAttrName(result.name),
381 UnitAttr::get(parser.getContext()));
382
383 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
384 parser.parseType(elemType) ||
385 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
386 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
387 return failure();
388
389 std::optional<NamedAttribute> alignmentAttr =
390 result.attributes.getNamed("alignment");
391 if (alignmentAttr.has_value()) {
392 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
393 if (!alignmentInt)
394 return parser.emitError(parser.getNameLoc(),
395 "expected integer alignment");
396 if (alignmentInt.getValue().isZero())
397 result.attributes.erase("alignment");
398 }
399
400 // Extract the result type from the trailing function type.
401 auto funcType = llvm::dyn_cast<FunctionType>(type);
402 if (!funcType || funcType.getNumInputs() != 1 ||
403 funcType.getNumResults() != 1)
404 return parser.emitError(
405 trailingTypeLoc,
406 "expected trailing function type with one argument and one result");
407
408 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
409 return failure();
410
411 Type resultType = funcType.getResult(0);
412 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
413 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
414
415 result.addTypes({funcType.getResult(0)});
416 return success();
417}
418
419LogicalResult AllocaOp::verify() {
420 // Only certain target extension types can be used in 'alloca'.
421 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
422 targetExtType && !targetExtType.supportsMemOps())
423 return emitOpError()
424 << "this target extension type cannot be used in alloca";
425
426 return success();
427}
428
429//===----------------------------------------------------------------------===//
430// LLVM::BrOp
431//===----------------------------------------------------------------------===//
432
433SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
434 assert(index == 0 && "invalid successor index");
435 return SuccessorOperands(getDestOperandsMutable());
436}
437
438//===----------------------------------------------------------------------===//
439// LLVM::CondBrOp
440//===----------------------------------------------------------------------===//
441
442SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
443 assert(index < getNumSuccessors() && "invalid successor index");
444 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
445 : getFalseDestOperandsMutable());
446}
447
448void CondBrOp::build(OpBuilder &builder, OperationState &result,
449 Value condition, Block *trueDest, ValueRange trueOperands,
450 Block *falseDest, ValueRange falseOperands,
451 std::optional<std::pair<uint32_t, uint32_t>> weights) {
452 DenseI32ArrayAttr weightsAttr;
453 if (weights)
454 weightsAttr =
455 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
456 static_cast<int32_t>(weights->second)});
457
458 build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
459 /*loop_annotation=*/{}, trueDest, falseDest);
460}
461
462//===----------------------------------------------------------------------===//
463// LLVM::SwitchOp
464//===----------------------------------------------------------------------===//
465
466void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
467 Block *defaultDestination, ValueRange defaultOperands,
468 DenseIntElementsAttr caseValues,
469 BlockRange caseDestinations,
470 ArrayRef<ValueRange> caseOperands,
471 ArrayRef<int32_t> branchWeights) {
472 DenseI32ArrayAttr weightsAttr;
473 if (!branchWeights.empty())
474 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
475
476 build(builder, result, value, defaultOperands, caseOperands, caseValues,
477 weightsAttr, defaultDestination, caseDestinations);
478}
479
480void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
481 Block *defaultDestination, ValueRange defaultOperands,
482 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
483 ArrayRef<ValueRange> caseOperands,
484 ArrayRef<int32_t> branchWeights) {
485 DenseIntElementsAttr caseValuesAttr;
486 if (!caseValues.empty()) {
487 ShapedType caseValueType = VectorType::get(
488 static_cast<int64_t>(caseValues.size()), value.getType());
489 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
490 }
491
492 build(builder, result, value, defaultDestination, defaultOperands,
493 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
494}
495
496void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
497 Block *defaultDestination, ValueRange defaultOperands,
498 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
499 ArrayRef<ValueRange> caseOperands,
500 ArrayRef<int32_t> branchWeights) {
501 DenseIntElementsAttr caseValuesAttr;
502 if (!caseValues.empty()) {
503 ShapedType caseValueType = VectorType::get(
504 static_cast<int64_t>(caseValues.size()), value.getType());
505 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
506 }
507
508 build(builder, result, value, defaultDestination, defaultOperands,
509 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
510}
511
512/// <cases> ::= `[` (case (`,` case )* )? `]`
513/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
514static ParseResult parseSwitchOpCases(
515 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
516 SmallVectorImpl<Block *> &caseDestinations,
517 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
518 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
519 if (failed(Result: parser.parseLSquare()))
520 return failure();
521 if (succeeded(Result: parser.parseOptionalRSquare()))
522 return success();
523 SmallVector<APInt> values;
524 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
525 auto parseCase = [&]() {
526 int64_t value = 0;
527 if (failed(Result: parser.parseInteger(result&: value)))
528 return failure();
529 values.push_back(Elt: APInt(bitWidth, value, /*isSigned=*/true));
530
531 Block *destination;
532 SmallVector<OpAsmParser::UnresolvedOperand> operands;
533 SmallVector<Type> operandTypes;
534 if (parser.parseColon() || parser.parseSuccessor(dest&: destination))
535 return failure();
536 if (!parser.parseOptionalLParen()) {
537 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::None,
538 /*allowResultNumber=*/false) ||
539 parser.parseColonTypeList(result&: operandTypes) || parser.parseRParen())
540 return failure();
541 }
542 caseDestinations.push_back(Elt: destination);
543 caseOperands.emplace_back(Args&: operands);
544 caseOperandTypes.emplace_back(Args&: operandTypes);
545 return success();
546 };
547 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: parseCase)))
548 return failure();
549
550 ShapedType caseValueType =
551 VectorType::get(static_cast<int64_t>(values.size()), flagType);
552 caseValues = DenseIntElementsAttr::get(caseValueType, values);
553 return parser.parseRSquare();
554}
555
556static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
557 DenseIntElementsAttr caseValues,
558 SuccessorRange caseDestinations,
559 OperandRangeRange caseOperands,
560 const TypeRangeRange &caseOperandTypes) {
561 p << '[';
562 p.printNewline();
563 if (!caseValues) {
564 p << ']';
565 return;
566 }
567
568 size_t index = 0;
569 llvm::interleave(
570 c: llvm::zip(t&: caseValues, u&: caseDestinations),
571 each_fn: [&](auto i) {
572 p << " ";
573 p << std::get<0>(i);
574 p << ": ";
575 p.printSuccessorAndUseList(successor: std::get<1>(i), succOperands: caseOperands[index++]);
576 },
577 between_fn: [&] {
578 p << ',';
579 p.printNewline();
580 });
581 p.printNewline();
582 p << ']';
583}
584
585LogicalResult SwitchOp::verify() {
586 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
587 (getCaseValues() &&
588 getCaseValues()->size() !=
589 static_cast<int64_t>(getCaseDestinations().size())))
590 return emitOpError("expects number of case values to match number of "
591 "case destinations");
592 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
593 return emitError("expects number of branch weights to match number of "
594 "successors: ")
595 << getBranchWeights()->size() << " vs " << getNumSuccessors();
596 if (getCaseValues() &&
597 getValue().getType() != getCaseValues()->getElementType())
598 return emitError("expects case value type to match condition value type");
599 return success();
600}
601
602SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
603 assert(index < getNumSuccessors() && "invalid successor index");
604 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
605 : getCaseOperandsMutable(index - 1));
606}
607
608//===----------------------------------------------------------------------===//
609// Code for LLVM::GEPOp.
610//===----------------------------------------------------------------------===//
611
// Out-of-line definition of the sentinel that this file stores in
// `rawConstantIndices` to mark positions occupied by dynamic (SSA value)
// indices.
// NOTE(review): if the in-class declaration is `static constexpr`, this
// definition is redundant (and deprecated) under C++17 — confirm and remove.
constexpr int32_t GEPOp::kDynamicIndex;
613
614GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
615 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
616 getDynamicIndices());
617}
618
619/// Returns the elemental type of any LLVM-compatible vector type or self.
620static Type extractVectorElementType(Type type) {
621 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
622 return vectorType.getElementType();
623 return type;
624}
625
/// Destructures the 'indices' parameter into 'rawConstantIndices' and
/// 'dynamicIndices', encoding the former in the process. In the process,
/// dynamic indices which are used to index into a structure type are converted
/// to constant indices when possible. To do this, the GEP's element type must
/// be passed as the first parameter (`currType`).
///
/// Every index appends exactly one entry to `rawConstantIndices`: its constant
/// value, or `GEPOp::kDynamicIndex` as a placeholder, in which case the Value
/// is appended to `dynamicIndices`. `currType` tracks the type indexed by the
/// next index. It becomes null once it can no longer be determined (a
/// non-vector/array/struct type, or a dynamic or out-of-range struct member
/// index). From then on, no struct-index constant folding happens.
static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
                               SmallVectorImpl<int32_t> &rawConstantIndices,
                               SmallVectorImpl<Value> &dynamicIndices) {
  for (const GEPArg &iter : indices) {
    // If the thing we are currently indexing into is a struct we must turn
    // any integer constants into constant indices. If this is not possible
    // we don't do anything here. The verifier will catch it and emit a proper
    // error. All other canonicalization is done in the fold method.
    // An empty `rawConstantIndices` identifies the first index, which only
    // offsets the base pointer (this assumes the caller passes the vector in
    // empty, as GEPOp::build does).
    bool requiresConst = !rawConstantIndices.empty() &&
                         isa_and_nonnull<LLVMStructType>(currType);
    if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
      APInt intC;
      // Only constants that fit the constant-index encoding width are folded.
      if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
          intC.isSignedIntN(kGEPConstantBitWidth)) {
        rawConstantIndices.push_back(intC.getSExtValue());
      } else {
        rawConstantIndices.push_back(GEPOp::kDynamicIndex);
        dynamicIndices.push_back(val);
      }
    } else {
      rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
    }

    // Skip for very first iteration of this loop. First index does not index
    // within the aggregates, but is just a pointer offset.
    if (rawConstantIndices.size() == 1 || !currType)
      continue;

    // Step into the indexed type. For structs the just-pushed entry selects
    // the member; a dynamic placeholder (negative) yields null.
    currType = TypeSwitch<Type, Type>(currType)
                   .Case<VectorType, LLVMArrayType>([](auto containerType) {
                     return containerType.getElementType();
                   })
                   .Case([&](LLVMStructType structType) -> Type {
                     int64_t memberIndex = rawConstantIndices.back();
                     if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
                                                 structType.getBody().size())
                       return structType.getBody()[memberIndex];
                     return nullptr;
                   })
                   .Default(Type(nullptr));
  }
}
673
674void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
675 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
676 GEPNoWrapFlags noWrapFlags,
677 ArrayRef<NamedAttribute> attributes) {
678 SmallVector<int32_t> rawConstantIndices;
679 SmallVector<Value> dynamicIndices;
680 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
681
682 result.addTypes(resultType);
683 result.addAttributes(attributes);
684 result.getOrAddProperties<Properties>().rawConstantIndices =
685 builder.getDenseI32ArrayAttr(rawConstantIndices);
686 result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
687 result.getOrAddProperties<Properties>().elem_type =
688 TypeAttr::get(elementType);
689 result.addOperands(basePtr);
690 result.addOperands(dynamicIndices);
691}
692
693void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
694 Type elementType, Value basePtr, ValueRange indices,
695 GEPNoWrapFlags noWrapFlags,
696 ArrayRef<NamedAttribute> attributes) {
697 build(builder, result, resultType, elementType, basePtr,
698 SmallVector<GEPArg>(indices), noWrapFlags, attributes);
699}
700
701static ParseResult
702parseGEPIndices(OpAsmParser &parser,
703 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
704 DenseI32ArrayAttr &rawConstantIndices) {
705 SmallVector<int32_t> constantIndices;
706
707 auto idxParser = [&]() -> ParseResult {
708 int32_t constantIndex;
709 OptionalParseResult parsedInteger =
710 parser.parseOptionalInteger(result&: constantIndex);
711 if (parsedInteger.has_value()) {
712 if (failed(Result: parsedInteger.value()))
713 return failure();
714 constantIndices.push_back(Elt: constantIndex);
715 return success();
716 }
717
718 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
719 return parser.parseOperand(result&: indices.emplace_back());
720 };
721 if (parser.parseCommaSeparatedList(parseElementFn: idxParser))
722 return failure();
723
724 rawConstantIndices =
725 DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
726 return success();
727}
728
729static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
730 OperandRange indices,
731 DenseI32ArrayAttr rawConstantIndices) {
732 llvm::interleaveComma(
733 c: GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), os&: printer,
734 each_fn: [&](PointerUnion<IntegerAttr, Value> cst) {
735 if (Value val = llvm::dyn_cast_if_present<Value>(Val&: cst))
736 printer.printOperand(value: val);
737 else
738 printer << cast<IntegerAttr>(cst).getInt();
739 });
740}
741
742/// For the given `indices`, check if they comply with `baseGEPType`,
743/// especially check against LLVMStructTypes nested within.
744static LogicalResult
745verifyStructIndices(Type baseGEPType, unsigned indexPos,
746 GEPIndicesAdaptor<ValueRange> indices,
747 function_ref<InFlightDiagnostic()> emitOpError) {
748 if (indexPos >= indices.size())
749 // Stop searching
750 return success();
751
752 return TypeSwitch<Type, LogicalResult>(baseGEPType)
753 .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
754 auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
755 if (!attr)
756 return emitOpError() << "expected index " << indexPos
757 << " indexing a struct to be constant";
758
759 int32_t gepIndex = attr.getInt();
760 ArrayRef<Type> elementTypes = structType.getBody();
761 if (gepIndex < 0 ||
762 static_cast<size_t>(gepIndex) >= elementTypes.size())
763 return emitOpError() << "index " << indexPos
764 << " indexing a struct is out of bounds";
765
766 // Instead of recursively going into every children types, we only
767 // dive into the one indexed by gepIndex.
768 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
769 indices, emitOpError);
770 })
771 .Case<VectorType, LLVMArrayType>(
772 [&](auto containerType) -> LogicalResult {
773 return verifyStructIndices(containerType.getElementType(),
774 indexPos + 1, indices, emitOpError);
775 })
776 .Default([&](auto otherType) -> LogicalResult {
777 return emitOpError()
778 << "type " << otherType << " cannot be indexed (index #"
779 << indexPos << ")";
780 });
781}
782
783/// Driver function around `verifyStructIndices`.
784static LogicalResult
785verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
786 function_ref<InFlightDiagnostic()> emitOpError) {
787 return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
788}
789
790LogicalResult LLVM::GEPOp::verify() {
791 if (static_cast<size_t>(
792 llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
793 getDynamicIndices().size())
794 return emitOpError("expected as many dynamic indices as specified in '")
795 << getRawConstantIndicesAttrName().getValue() << "'";
796
797 if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
798 return emitOpError("'inbounds_flag' cannot be used directly.");
799
800 return verifyStructIndices(getElemType(), getIndices(),
801 [&] { return emitOpError(); });
802}
803
804//===----------------------------------------------------------------------===//
805// LoadOp
806//===----------------------------------------------------------------------===//
807
808void LoadOp::getEffects(
809 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
810 &effects) {
811 effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
812 // Volatile operations can have target-specific read-write effects on
813 // memory besides the one referred to by the pointer operand.
814 // Similarly, atomic operations that are monotonic or stricter cause
815 // synchronization that from a language point-of-view, are arbitrary
816 // read-writes into memory.
817 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
818 getOrdering() != AtomicOrdering::unordered)) {
819 effects.emplace_back(MemoryEffects::Write::get());
820 effects.emplace_back(MemoryEffects::Read::get());
821 }
822}
823
824/// Returns true if the given type is supported by atomic operations. All
825/// integer, float, and pointer types with a power-of-two bitsize and a minimal
826/// size of 8 bits are supported.
827static bool isTypeCompatibleWithAtomicOp(Type type,
828 const DataLayout &dataLayout) {
829 if (!isa<IntegerType, LLVMPointerType>(type))
830 if (!isCompatibleFloatingPointType(type))
831 return false;
832
833 llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(t: type);
834 if (bitWidth.isScalable())
835 return false;
836 // Needs to be at least 8 bits and a power of two.
837 return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
838}
839
840/// Verifies the attributes and the type of atomic memory access operations.
841template <typename OpTy>
842LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
843 ArrayRef<AtomicOrdering> unsupportedOrderings) {
844 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
845 DataLayout dataLayout = DataLayout::closest(op: memOp);
846 if (!isTypeCompatibleWithAtomicOp(type: valueType, dataLayout))
847 return memOp.emitOpError("unsupported type ")
848 << valueType << " for atomic access";
849 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
850 return memOp.emitOpError("unsupported ordering '")
851 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
852 if (!memOp.getAlignment())
853 return memOp.emitOpError("expected alignment for atomic access");
854 return success();
855 }
856 if (memOp.getSyncscope())
857 return memOp.emitOpError(
858 "expected syncscope to be null for non-atomic access");
859 return success();
860}
861
862LogicalResult LoadOp::verify() {
863 Type valueType = getResult().getType();
864 return verifyAtomicMemOp(*this, valueType,
865 {AtomicOrdering::release, AtomicOrdering::acq_rel});
866}
867
868void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
869 Value addr, unsigned alignment, bool isVolatile,
870 bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
871 AtomicOrdering ordering, StringRef syncscope) {
872 build(builder, state, type, addr,
873 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
874 isNonTemporal, isInvariant, isInvariantGroup, ordering,
875 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
876 /*dereferenceable=*/nullptr,
877 /*access_groups=*/nullptr,
878 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
879 /*tbaa=*/nullptr);
880}
881
882//===----------------------------------------------------------------------===//
883// StoreOp
884//===----------------------------------------------------------------------===//
885
886void StoreOp::getEffects(
887 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
888 &effects) {
889 effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
890 // Volatile operations can have target-specific read-write effects on
891 // memory besides the one referred to by the pointer operand.
892 // Similarly, atomic operations that are monotonic or stricter cause
893 // synchronization that from a language point-of-view, are arbitrary
894 // read-writes into memory.
895 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
896 getOrdering() != AtomicOrdering::unordered)) {
897 effects.emplace_back(MemoryEffects::Write::get());
898 effects.emplace_back(MemoryEffects::Read::get());
899 }
900}
901
902LogicalResult StoreOp::verify() {
903 Type valueType = getValue().getType();
904 return verifyAtomicMemOp(*this, valueType,
905 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
906}
907
908void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
909 Value addr, unsigned alignment, bool isVolatile,
910 bool isNonTemporal, bool isInvariantGroup,
911 AtomicOrdering ordering, StringRef syncscope) {
912 build(builder, state, value, addr,
913 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
914 isNonTemporal, isInvariantGroup, ordering,
915 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
916 /*access_groups=*/nullptr,
917 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
918}
919
920//===----------------------------------------------------------------------===//
921// CallOp
922//===----------------------------------------------------------------------===//
923
924/// Gets the MLIR Op-like result types of a LLVMFunctionType.
925static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
926 SmallVector<Type, 1> results;
927 Type resultType = calleeType.getReturnType();
928 if (!isa<LLVM::LLVMVoidType>(Val: resultType))
929 results.push_back(Elt: resultType);
930 return results;
931}
932
933/// Gets the variadic callee type for a LLVMFunctionType.
934static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
935 return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
936}
937
938/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
939static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
940 ValueRange args) {
941 Type resultType;
942 if (results.empty())
943 resultType = LLVMVoidType::get(ctx: context);
944 else
945 resultType = results.front();
946 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
947 /*isVarArg=*/false);
948}
949
950void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
951 StringRef callee, ValueRange args) {
952 build(builder, state, results, builder.getStringAttr(callee), args);
953}
954
955void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
956 StringAttr callee, ValueRange args) {
957 build(builder, state, results, SymbolRefAttr::get(callee), args);
958}
959
960void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
961 FlatSymbolRefAttr callee, ValueRange args) {
962 assert(callee && "expected non-null callee in direct call builder");
963 build(builder, state, results,
964 /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
965 /*branch_weights=*/nullptr,
966 /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
967 /*memory_effects=*/nullptr,
968 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
969 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
970 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
971 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
972 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
973 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
974 /*inline_hint=*/nullptr);
975}
976
977void CallOp::build(OpBuilder &builder, OperationState &state,
978 LLVMFunctionType calleeType, StringRef callee,
979 ValueRange args) {
980 build(builder, state, calleeType, builder.getStringAttr(callee), args);
981}
982
983void CallOp::build(OpBuilder &builder, OperationState &state,
984 LLVMFunctionType calleeType, StringAttr callee,
985 ValueRange args) {
986 build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
987}
988
989void CallOp::build(OpBuilder &builder, OperationState &state,
990 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
991 ValueRange args) {
992 build(builder, state, getCallOpResultTypes(calleeType),
993 getCallOpVarCalleeType(calleeType), callee, args,
994 /*fastmathFlags=*/nullptr,
995 /*branch_weights=*/nullptr, /*CConv=*/nullptr,
996 /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
997 /*convergent=*/nullptr,
998 /*no_unwind=*/nullptr, /*will_return=*/nullptr,
999 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1000 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1001 /*access_groups=*/nullptr,
1002 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1003 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1004 /*inline_hint=*/nullptr);
1005}
1006
1007void CallOp::build(OpBuilder &builder, OperationState &state,
1008 LLVMFunctionType calleeType, ValueRange args) {
1009 build(builder, state, getCallOpResultTypes(calleeType),
1010 getCallOpVarCalleeType(calleeType),
1011 /*callee=*/nullptr, args,
1012 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1013 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1014 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1015 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1016 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1017 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1018 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1019 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1020 /*inline_hint=*/nullptr);
1021}
1022
1023void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1024 ValueRange args) {
1025 auto calleeType = func.getFunctionType();
1026 build(builder, state, getCallOpResultTypes(calleeType),
1027 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
1028 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1029 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1030 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1031 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1032 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1033 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1034 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1035 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1036 /*inline_hint=*/nullptr);
1037}
1038
1039CallInterfaceCallable CallOp::getCallableForCallee() {
1040 // Direct call.
1041 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1042 return calleeAttr;
1043 // Indirect call, callee Value is the first operand.
1044 return getOperand(0);
1045}
1046
1047void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1048 // Direct call.
1049 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1050 auto symRef = cast<SymbolRefAttr>(callee);
1051 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1052 }
1053 // Indirect call, callee Value is the first operand.
1054 return setOperand(0, cast<Value>(callee));
1055}
1056
1057Operation::operand_range CallOp::getArgOperands() {
1058 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1059}
1060
1061MutableOperandRange CallOp::getArgOperandsMutable() {
1062 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1063 getCalleeOperands().size());
1064}
1065
1066/// Verify that an inlinable callsite of a debug-info-bearing function in a
1067/// debug-info-bearing function has a debug location attached to it. This
1068/// mirrors an LLVM IR verifier.
1069static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1070 if (callee.isExternal())
1071 return success();
1072 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1073 if (!parentFunc)
1074 return success();
1075
1076 auto hasSubprogram = [](Operation *op) {
1077 return op->getLoc()
1078 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1079 nullptr;
1080 };
1081 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1082 return success();
1083 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1084 if (!containsLoc)
1085 return callOp.emitError()
1086 << "inlinable function call in a function with a DISubprogram "
1087 "location must have a debug location";
1088 return success();
1089}
1090
1091/// Verify that the parameter and return types of the variadic callee type match
1092/// the `callOp` argument and result types.
1093template <typename OpTy>
1094LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1095 std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1096 if (!varCalleeType)
1097 return success();
1098
1099 // Verify the variadic callee type is a variadic function type.
1100 if (!varCalleeType->isVarArg())
1101 return callOp.emitOpError(
1102 "expected var_callee_type to be a variadic function type");
1103
1104 // Verify the variadic callee type has at most as many parameters as the call
1105 // has argument operands.
1106 if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1107 return callOp.emitOpError("expected var_callee_type to have at most ")
1108 << callOp.getArgOperands().size() << " parameters";
1109
1110 // Verify the variadic callee type matches the call argument types.
1111 for (auto [paramType, operand] :
1112 llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1113 if (paramType != operand.getType())
1114 return callOp.emitOpError()
1115 << "var_callee_type parameter type mismatch: " << paramType
1116 << " != " << operand.getType();
1117
1118 // Verify the variadic callee type matches the call result type.
1119 if (!callOp.getNumResults()) {
1120 if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1121 return callOp.emitOpError("expected var_callee_type to return void");
1122 } else {
1123 if (callOp.getResult().getType() != varCalleeType->getReturnType())
1124 return callOp.emitOpError("var_callee_type return type mismatch: ")
1125 << varCalleeType->getReturnType()
1126 << " != " << callOp.getResult().getType();
1127 }
1128 return success();
1129}
1130
1131template <typename OpType>
1132static LogicalResult verifyOperandBundles(OpType &op) {
1133 OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1134 std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1135
1136 auto isStringAttr = [](Attribute tagAttr) {
1137 return isa<StringAttr>(Val: tagAttr);
1138 };
1139 if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1140 return op.emitError("operand bundle tag must be a StringAttr");
1141
1142 size_t numOpBundles = opBundleOperands.size();
1143 size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1144 if (numOpBundles != numOpBundleTags)
1145 return op.emitError("expected ")
1146 << numOpBundles << " operand bundle tags, but actually got "
1147 << numOpBundleTags;
1148
1149 return success();
1150}
1151
1152LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
1153
1154LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1155 if (failed(verifyCallOpVarCalleeType(*this)))
1156 return failure();
1157
1158 // Type for the callee, we'll get it differently depending if it is a direct
1159 // or indirect call.
1160 Type fnType;
1161
1162 bool isIndirect = false;
1163
1164 // If this is an indirect call, the callee attribute is missing.
1165 FlatSymbolRefAttr calleeName = getCalleeAttr();
1166 if (!calleeName) {
1167 isIndirect = true;
1168 if (!getNumOperands())
1169 return emitOpError(
1170 "must have either a `callee` attribute or at least an operand");
1171 auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1172 if (!ptrType)
1173 return emitOpError("indirect call expects a pointer as callee: ")
1174 << getOperand(0).getType();
1175
1176 return success();
1177 } else {
1178 Operation *callee =
1179 symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1180 if (!callee)
1181 return emitOpError()
1182 << "'" << calleeName.getValue()
1183 << "' does not reference a symbol in the current scope";
1184 auto fn = dyn_cast<LLVMFuncOp>(callee);
1185 if (!fn)
1186 return emitOpError() << "'" << calleeName.getValue()
1187 << "' does not reference a valid LLVM function";
1188
1189 if (failed(verifyCallOpDebugInfo(*this, fn)))
1190 return failure();
1191 fnType = fn.getFunctionType();
1192 }
1193
1194 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1195 if (!funcType)
1196 return emitOpError("callee does not have a functional type: ") << fnType;
1197
1198 if (funcType.isVarArg() && !getVarCalleeType())
1199 return emitOpError() << "missing var_callee_type attribute for vararg call";
1200
1201 // Verify that the operand and result types match the callee.
1202
1203 if (!funcType.isVarArg() &&
1204 funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
1205 return emitOpError() << "incorrect number of operands ("
1206 << (getCalleeOperands().size() - isIndirect)
1207 << ") for callee (expecting: "
1208 << funcType.getNumParams() << ")";
1209
1210 if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
1211 return emitOpError() << "incorrect number of operands ("
1212 << (getCalleeOperands().size() - isIndirect)
1213 << ") for varargs callee (expecting at least: "
1214 << funcType.getNumParams() << ")";
1215
1216 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1217 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1218 return emitOpError() << "operand type mismatch for operand " << i << ": "
1219 << getOperand(i + isIndirect).getType()
1220 << " != " << funcType.getParamType(i);
1221
1222 if (getNumResults() == 0 &&
1223 !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1224 return emitOpError() << "expected function call to produce a value";
1225
1226 if (getNumResults() != 0 &&
1227 llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1228 return emitOpError()
1229 << "calling function with void result must not produce values";
1230
1231 if (getNumResults() > 1)
1232 return emitOpError()
1233 << "expected LLVM function call to produce 0 or 1 result";
1234
1235 if (getNumResults() && getResult().getType() != funcType.getReturnType())
1236 return emitOpError() << "result type mismatch: " << getResult().getType()
1237 << " != " << funcType.getReturnType();
1238
1239 return success();
1240}
1241
/// Prints the custom assembly form of `llvm.call`:
///   (cconv)? (tailcallkind)? (@callee | %fn_ptr) `(` args `)`
///   (`vararg(` function-type `)`)? (op-bundles)? attr-dict
///   `:` (fn-ptr-type `,`)? function-signature
/// The calling convention is omitted when it is `C`, and the tail call kind
/// when it is `None`. Attributes spelled by the custom syntax are elided from
/// the printed attribute dictionary.
void CallOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();

  p << ' ';

  // Print calling convention.
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  // Print the tail call kind unless it is the default `None`.
  if (getTailCallKind() != LLVM::TailCallKind::None)
    p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';

  // Print the direct callee if present as a function attribute, or an indirect
  // callee (first operand) otherwise.
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);

  // Call arguments: the callee operands minus the leading function pointer of
  // an indirect call.
  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
  p << '(' << args << ')';

  // Print the variadic callee type if the call is variadic.
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  if (!getOpBundleOperands().empty()) {
    p << " ";
    printOpBundles(p, *this, getOpBundleOperands(),
                   getOpBundleOperands().getTypes(), getOpBundleTags());
  }

  // NOTE(review): `processFMFAttr` is defined elsewhere in this file; it
  // presumably drops a default fastmath attribute before printing — confirm.
  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                          {getCalleeAttrName(), getTailCallKindAttrName(),
                           getVarCalleeTypeAttrName(), getCConvAttrName(),
                           getOperandSegmentSizesAttrName(),
                           getOpBundleSizesAttrName(),
                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
                           getResAttrsAttrName()});

  // Trailing types: the function pointer type (indirect calls only), then the
  // signature.
  p << " : ";
  if (!isDirect)
    p << getOperand(0).getType() << ", ";

  // Reconstruct the MLIR function type from operand and result types.
  call_interface_impl::printFunctionSignature(
      p, args.getTypes(), getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
1292
1293/// Parses the type of a call operation and resolves the operands if the parsing
1294/// succeeds. Returns failure otherwise.
1295static ParseResult parseCallTypeAndResolveOperands(
1296 OpAsmParser &parser, OperationState &result, bool isDirect,
1297 ArrayRef<OpAsmParser::UnresolvedOperand> operands,
1298 SmallVectorImpl<DictionaryAttr> &argAttrs,
1299 SmallVectorImpl<DictionaryAttr> &resultAttrs) {
1300 SMLoc trailingTypesLoc = parser.getCurrentLocation();
1301 SmallVector<Type> types;
1302 if (parser.parseColon())
1303 return failure();
1304 if (!isDirect) {
1305 types.emplace_back();
1306 if (parser.parseType(result&: types.back()))
1307 return failure();
1308 if (parser.parseOptionalComma())
1309 return parser.emitError(
1310 loc: trailingTypesLoc, message: "expected indirect call to have 2 trailing types");
1311 }
1312 SmallVector<Type> argTypes;
1313 SmallVector<Type> resTypes;
1314 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1315 resultTypes&: resTypes, resultAttrs)) {
1316 if (isDirect)
1317 return parser.emitError(loc: trailingTypesLoc,
1318 message: "expected direct call to have 1 trailing types");
1319 return parser.emitError(loc: trailingTypesLoc,
1320 message: "expected trailing function type");
1321 }
1322
1323 if (resTypes.size() > 1)
1324 return parser.emitError(loc: trailingTypesLoc,
1325 message: "expected function with 0 or 1 result");
1326 if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(Val: resTypes[0]))
1327 return parser.emitError(loc: trailingTypesLoc,
1328 message: "expected a non-void result type");
1329
1330 // The head element of the types list matches the callee type for
1331 // indirect calls, while the types list is emtpy for direct calls.
1332 // Append the function input types to resolve the call operation
1333 // operands.
1334 llvm::append_range(C&: types, R&: argTypes);
1335 if (parser.resolveOperands(operands, types, loc: parser.getNameLoc(),
1336 result&: result.operands))
1337 return failure();
1338 if (resTypes.size() != 0)
1339 result.addTypes(newTypes: resTypes);
1340
1341 return success();
1342}
1343
1344/// Parses an optional function pointer operand before the call argument list
1345/// for indirect calls, or stops parsing at the function identifier otherwise.
1346static ParseResult parseOptionalCallFuncPtr(
1347 OpAsmParser &parser,
1348 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) {
1349 OpAsmParser::UnresolvedOperand funcPtrOperand;
1350 OptionalParseResult parseResult = parser.parseOptionalOperand(result&: funcPtrOperand);
1351 if (parseResult.has_value()) {
1352 if (failed(Result: *parseResult))
1353 return *parseResult;
1354 operands.push_back(Elt: funcPtrOperand);
1355 }
1356 return success();
1357}
1358
1359static ParseResult resolveOpBundleOperands(
1360 OpAsmParser &parser, SMLoc loc, OperationState &state,
1361 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands,
1362 ArrayRef<SmallVector<Type>> opBundleOperandTypes,
1363 StringAttr opBundleSizesAttrName) {
1364 unsigned opBundleIndex = 0;
1365 for (const auto &[operands, types] :
1366 llvm::zip_equal(t&: opBundleOperands, u&: opBundleOperandTypes)) {
1367 if (operands.size() != types.size())
1368 return parser.emitError(loc, message: "expected ")
1369 << operands.size()
1370 << " types for operand bundle operands for operand bundle #"
1371 << opBundleIndex << ", but actually got " << types.size();
1372 if (parser.resolveOperands(operands, types, loc, result&: state.operands))
1373 return failure();
1374 }
1375
1376 SmallVector<int32_t> opBundleSizes;
1377 opBundleSizes.reserve(N: opBundleOperands.size());
1378 for (const auto &operands : opBundleOperands)
1379 opBundleSizes.push_back(Elt: operands.size());
1380
1381 state.addAttribute(
1382 opBundleSizesAttrName,
1383 DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
1384
1385 return success();
1386}
1387
1388// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1389// `(` ssa-use-list `)`
1390// ( `vararg(` var-callee-type `)` )?
1391// ( `[` op-bundles-list `]` )?
1392// attribute-dict? `:` (type `,`)? function-type
1393ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1394 SymbolRefAttr funcAttr;
1395 TypeAttr varCalleeType;
1396 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1397 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
1398 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1399 ArrayAttr opBundleTags;
1400
1401 // Default to C Calling Convention if no keyword is provided.
1402 result.addAttribute(
1403 getCConvAttrName(result.name),
1404 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1405 parser, result, LLVM::CConv::C)));
1406
1407 result.addAttribute(
1408 getTailCallKindAttrName(result.name),
1409 TailCallKindAttr::get(parser.getContext(),
1410 parseOptionalLLVMKeyword<TailCallKind>(
1411 parser, result, LLVM::TailCallKind::None)));
1412
1413 // Parse a function pointer for indirect calls.
1414 if (parseOptionalCallFuncPtr(parser, operands))
1415 return failure();
1416 bool isDirect = operands.empty();
1417
1418 // Parse a function identifier for direct calls.
1419 if (isDirect)
1420 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1421 return failure();
1422
1423 // Parse the function arguments.
1424 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1425 return failure();
1426
1427 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1428 if (isVarArg) {
1429 StringAttr varCalleeTypeAttrName =
1430 CallOp::getVarCalleeTypeAttrName(result.name);
1431 if (parser.parseLParen().failed() ||
1432 parser
1433 .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1434 result.attributes)
1435 .failed() ||
1436 parser.parseRParen().failed())
1437 return failure();
1438 }
1439
1440 SMLoc opBundlesLoc = parser.getCurrentLocation();
1441 if (std::optional<ParseResult> result = parseOpBundles(
1442 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1443 result && failed(*result))
1444 return failure();
1445 if (opBundleTags && !opBundleTags.empty())
1446 result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
1447 opBundleTags);
1448
1449 if (parser.parseOptionalAttrDict(result.attributes))
1450 return failure();
1451
1452 // Parse the trailing type list and resolve the operands.
1453 SmallVector<DictionaryAttr> argAttrs;
1454 SmallVector<DictionaryAttr> resultAttrs;
1455 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1456 argAttrs, resultAttrs))
1457 return failure();
1458 call_interface_impl::addArgAndResultAttrs(
1459 parser.getBuilder(), result, argAttrs, resultAttrs,
1460 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1461 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1462 opBundleOperandTypes,
1463 getOpBundleSizesAttrName(result.name)))
1464 return failure();
1465
1466 int32_t numOpBundleOperands = 0;
1467 for (const auto &operands : opBundleOperands)
1468 numOpBundleOperands += operands.size();
1469
1470 result.addAttribute(
1471 CallOp::getOperandSegmentSizeAttr(),
1472 parser.getBuilder().getDenseI32ArrayAttr(
1473 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
1474 return success();
1475}
1476
1477LLVMFunctionType CallOp::getCalleeFunctionType() {
1478 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1479 return *varCalleeType;
1480 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1481}
1482
1483///===---------------------------------------------------------------------===//
1484/// LLVM::InvokeOp
1485///===---------------------------------------------------------------------===//
1486
1487void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1488 ValueRange ops, Block *normal, ValueRange normalOps,
1489 Block *unwind, ValueRange unwindOps) {
1490 auto calleeType = func.getFunctionType();
1491 build(builder, state, getCallOpResultTypes(calleeType),
1492 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1493 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1494 nullptr, nullptr, {}, {}, normal, unwind);
1495}
1496
1497void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1498 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1499 ValueRange normalOps, Block *unwind,
1500 ValueRange unwindOps) {
1501 build(builder, state, tys,
1502 /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr,
1503 /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {},
1504 normal, unwind);
1505}
1506
1507void InvokeOp::build(OpBuilder &builder, OperationState &state,
1508 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1509 ValueRange ops, Block *normal, ValueRange normalOps,
1510 Block *unwind, ValueRange unwindOps) {
1511 build(builder, state, getCallOpResultTypes(calleeType),
1512 getCallOpVarCalleeType(calleeType), callee, ops,
1513 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1514 nullptr, nullptr, {}, {}, normal, unwind);
1515}
1516
1517SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1518 assert(index < getNumSuccessors() && "invalid successor index");
1519 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1520 : getUnwindDestOperandsMutable());
1521}
1522
1523CallInterfaceCallable InvokeOp::getCallableForCallee() {
1524 // Direct call.
1525 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1526 return calleeAttr;
1527 // Indirect call, callee Value is the first operand.
1528 return getOperand(0);
1529}
1530
1531void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1532 // Direct call.
1533 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1534 auto symRef = cast<SymbolRefAttr>(callee);
1535 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1536 }
1537 // Indirect call, callee Value is the first operand.
1538 return setOperand(0, cast<Value>(callee));
1539}
1540
1541Operation::operand_range InvokeOp::getArgOperands() {
1542 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1543}
1544
1545MutableOperandRange InvokeOp::getArgOperandsMutable() {
1546 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1547 getCalleeOperands().size());
1548}
1549
1550LogicalResult InvokeOp::verify() {
1551 if (failed(verifyCallOpVarCalleeType(*this)))
1552 return failure();
1553
1554 Block *unwindDest = getUnwindDest();
1555 if (unwindDest->empty())
1556 return emitError("must have at least one operation in unwind destination");
1557
1558 // In unwind destination, first operation must be LandingpadOp
1559 if (!isa<LandingpadOp>(unwindDest->front()))
1560 return emitError("first operation in unwind destination should be a "
1561 "llvm.landingpad operation");
1562
1563 if (failed(verifyOperandBundles(*this)))
1564 return failure();
1565
1566 return success();
1567}
1568
/// Prints the custom assembly form of `llvm.invoke`:
///   (cconv)? (@callee | %fn_ptr) `(` args `)`
///   `to` normal-dest `unwind` unwind-dest
///   (`vararg(` function-type `)`)? (op-bundles)? attr-dict
///   `:` (fn-ptr-type `,`)? function-signature
/// The calling convention is omitted when it is `C`. Attributes spelled by the
/// custom syntax are elided from the printed attribute dictionary.
void InvokeOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();

  p << ' ';

  // Print calling convention.
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  // Either function name or pointer
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);

  // Arguments (callee operands minus the function pointer of an indirect
  // invoke), followed by the normal and unwind successors with their operands.
  p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
  p << " to ";
  p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
  p << " unwind ";
  p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());

  // Print the variadic callee type if the invoke is variadic.
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  if (!getOpBundleOperands().empty()) {
    p << " ";
    printOpBundles(p, *this, getOpBundleOperands(),
                   getOpBundleOperands().getTypes(), getOpBundleTags());
  }

  p.printOptionalAttrDict((*this)->getAttrs(),
                          {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                           getCConvAttrName(), getVarCalleeTypeAttrName(),
                           getOpBundleSizesAttrName(),
                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
                           getResAttrsAttrName()});

  // Trailing types: the function pointer type (indirect invokes only), then
  // the signature reconstructed from argument and result types.
  p << " : ";
  if (!isDirect)
    p << getOperand(0).getType() << ", ";
  call_interface_impl::printFunctionSignature(
      p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
      getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
1616
1617// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1618// `(` ssa-use-list `)`
1619// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1620// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1621// ( `vararg(` var-callee-type `)` )?
1622// ( `[` op-bundles-list `]` )?
1623// attribute-dict? `:` (type `,`)?
1624// function-type-with-argument-attributes
1625ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1626 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1627 SymbolRefAttr funcAttr;
1628 TypeAttr varCalleeType;
1629 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
1630 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1631 ArrayAttr opBundleTags;
1632 Block *normalDest, *unwindDest;
1633 SmallVector<Value, 4> normalOperands, unwindOperands;
1634 Builder &builder = parser.getBuilder();
1635
1636 // Default to C Calling Convention if no keyword is provided.
1637 result.addAttribute(
1638 getCConvAttrName(result.name),
1639 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1640 parser, result, LLVM::CConv::C)));
1641
1642 // Parse a function pointer for indirect calls.
1643 if (parseOptionalCallFuncPtr(parser, operands))
1644 return failure();
1645 bool isDirect = operands.empty();
1646
1647 // Parse a function identifier for direct calls.
1648 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1649 return failure();
1650
1651 // Parse the function arguments.
1652 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1653 parser.parseKeyword("to") ||
1654 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1655 parser.parseKeyword("unwind") ||
1656 parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1657 return failure();
1658
1659 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1660 if (isVarArg) {
1661 StringAttr varCalleeTypeAttrName =
1662 InvokeOp::getVarCalleeTypeAttrName(result.name);
1663 if (parser.parseLParen().failed() ||
1664 parser
1665 .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1666 result.attributes)
1667 .failed() ||
1668 parser.parseRParen().failed())
1669 return failure();
1670 }
1671
1672 SMLoc opBundlesLoc = parser.getCurrentLocation();
1673 if (std::optional<ParseResult> result = parseOpBundles(
1674 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1675 result && failed(*result))
1676 return failure();
1677 if (opBundleTags && !opBundleTags.empty())
1678 result.addAttribute(
1679 InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
1680 opBundleTags);
1681
1682 if (parser.parseOptionalAttrDict(result.attributes))
1683 return failure();
1684
1685 // Parse the trailing type list and resolve the function operands.
1686 SmallVector<DictionaryAttr> argAttrs;
1687 SmallVector<DictionaryAttr> resultAttrs;
1688 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1689 argAttrs, resultAttrs))
1690 return failure();
1691 call_interface_impl::addArgAndResultAttrs(
1692 parser.getBuilder(), result, argAttrs, resultAttrs,
1693 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1694
1695 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1696 opBundleOperandTypes,
1697 getOpBundleSizesAttrName(result.name)))
1698 return failure();
1699
1700 result.addSuccessors({normalDest, unwindDest});
1701 result.addOperands(normalOperands);
1702 result.addOperands(unwindOperands);
1703
1704 int32_t numOpBundleOperands = 0;
1705 for (const auto &operands : opBundleOperands)
1706 numOpBundleOperands += operands.size();
1707
1708 result.addAttribute(
1709 InvokeOp::getOperandSegmentSizeAttr(),
1710 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
1711 static_cast<int32_t>(normalOperands.size()),
1712 static_cast<int32_t>(unwindOperands.size()),
1713 numOpBundleOperands}));
1714 return success();
1715}
1716
1717LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1718 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1719 return *varCalleeType;
1720 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1721}
1722
1723///===----------------------------------------------------------------------===//
1724/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1725///===----------------------------------------------------------------------===//
1726
1727LogicalResult LandingpadOp::verify() {
1728 Value value;
1729 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1730 if (!func.getPersonality())
1731 return emitError(
1732 "llvm.landingpad needs to be in a function with a personality");
1733 }
1734
1735 // Consistency of llvm.landingpad result types is checked in
1736 // LLVMFuncOp::verify().
1737
1738 if (!getCleanup() && getOperands().empty())
1739 return emitError("landingpad instruction expects at least one clause or "
1740 "cleanup attribute");
1741
1742 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1743 value = getOperand(idx);
1744 bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1745 if (isFilter) {
1746 // FIXME: Verify filter clauses when arrays are appropriately handled
1747 } else {
1748 // catch - global addresses only.
1749 // Bitcast ops should have global addresses as their args.
1750 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1751 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1752 continue;
1753 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1754 << "global addresses expected as operand to "
1755 "bitcast used in clauses for landingpad";
1756 }
1757 // ZeroOp and AddressOfOp allowed
1758 if (value.getDefiningOp<ZeroOp>())
1759 continue;
1760 if (value.getDefiningOp<AddressOfOp>())
1761 continue;
1762 return emitError("clause #")
1763 << idx << " is not a known constant - null, addressof, bitcast";
1764 }
1765 }
1766 return success();
1767}
1768
1769void LandingpadOp::print(OpAsmPrinter &p) {
1770 p << (getCleanup() ? " cleanup " : " ");
1771
1772 // Clauses
1773 for (auto value : getOperands()) {
1774 // Similar to llvm - if clause is an array type then it is filter
1775 // clause else catch clause
1776 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1777 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1778 << value.getType() << ") ";
1779 }
1780
1781 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1782
1783 p << ": " << getType();
1784}
1785
1786// <operation> ::= `llvm.landingpad` `cleanup`?
1787// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1788ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1789 // Check for cleanup
1790 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1791 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1792
1793 // Parse clauses with types
1794 while (succeeded(parser.parseOptionalLParen()) &&
1795 (succeeded(parser.parseOptionalKeyword("filter")) ||
1796 succeeded(parser.parseOptionalKeyword("catch")))) {
1797 OpAsmParser::UnresolvedOperand operand;
1798 Type ty;
1799 if (parser.parseOperand(operand) || parser.parseColon() ||
1800 parser.parseType(ty) ||
1801 parser.resolveOperand(operand, ty, result.operands) ||
1802 parser.parseRParen())
1803 return failure();
1804 }
1805
1806 Type type;
1807 if (parser.parseColon() || parser.parseType(type))
1808 return failure();
1809
1810 result.addTypes(type);
1811 return success();
1812}
1813
1814//===----------------------------------------------------------------------===//
1815// ExtractValueOp
1816//===----------------------------------------------------------------------===//
1817
1818/// Extract the type at `position` in the LLVM IR aggregate type
1819/// `containerType`. Each element of `position` is an index into a nested
1820/// aggregate type. Return the resulting type or emit an error.
1821static Type getInsertExtractValueElementType(
1822 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1823 ArrayRef<int64_t> position) {
1824 Type llvmType = containerType;
1825 if (!isCompatibleType(type: containerType)) {
1826 emitError("expected LLVM IR Dialect type, got ") << containerType;
1827 return {};
1828 }
1829
1830 // Infer the element type from the structure type: iteratively step inside the
1831 // type by taking the element type, indexed by the position attribute for
1832 // structures. Check the position index before accessing, it is supposed to
1833 // be in bounds.
1834 for (int64_t idx : position) {
1835 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1836 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1837 emitError("position out of bounds: ") << idx;
1838 return {};
1839 }
1840 llvmType = arrayType.getElementType();
1841 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1842 if (idx < 0 ||
1843 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1844 emitError("position out of bounds: ") << idx;
1845 return {};
1846 }
1847 llvmType = structType.getBody()[idx];
1848 } else {
1849 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1850 return {};
1851 }
1852 }
1853 return llvmType;
1854}
1855
1856/// Extract the type at `position` in the wrapped LLVM IR aggregate type
1857/// `containerType`.
1858static Type getInsertExtractValueElementType(Type llvmType,
1859 ArrayRef<int64_t> position) {
1860 for (int64_t idx : position) {
1861 if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1862 llvmType = structType.getBody()[idx];
1863 else
1864 llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1865 }
1866 return llvmType;
1867}
1868
/// Fold an `llvm.extractvalue`. The strategies below are tried in order.
/// Some of them rewrite this op in place (its position and/or its container
/// operand) and then return this op's own result to signal the in-place
/// change. A null result means nothing could be folded.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  // extractvalue(extractvalue(c, p0), p1) becomes extractvalue(c, p0 ++ p1),
  // applied in place on this op.
  if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
    SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
    newPos.append(getPosition().begin(), getPosition().end());
    setPosition(newPos);
    getContainerMutable().set(extractValueOp.getContainer());
    return getResult();
  }

  // Extraction from a constant container. This only applies when the constant
  // is a DenseElementsAttr whose element type is exactly this op's result
  // type:
  //  - a splat folds to its single value;
  //  - a single-index position folds to the element at that index.
  // All other cases fall through.
  {
    DenseElementsAttr constval;
    matchPattern(getContainer(), m_Constant(&constval));
    if (constval && constval.getElementType() == getType()) {
      if (isa<SplatElementsAttr>(constval))
        return constval.getSplatValue<Attribute>();
      if (getPosition().size() == 1)
        return constval.getValues<Attribute>()[getPosition()[0]];
    }
  }

  // Walk back through a chain of insertvalue ops. `result` stays null unless
  // this op's container operand gets replaced in place (Case 3). In that case
  // this op's own result is returned.
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  OpFoldResult result = {};
  // Position still to be resolved, relative to the value currently inspected.
  ArrayRef<int64_t> extractPos = getPosition();
  // Set once the walk has moved into an inserted value (Case 2). From then on,
  // this op's own container operand must not be rewritten.
  bool switchedToInsertedValue = false;
  while (insertValueOp) {
    ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
    auto extractPosSize = extractPos.size();
    auto insertPosSize = insertPos.size();

    // Case 1: Exact match of positions.
    if (extractPos == insertPos)
      return insertValueOp.getValue();

    // Case 2: Insert position is a prefix of extract position. Continue
    // traversal with the inserted value. Example:
    // ```
    //   %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
    //   %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
    //   %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
    //   %3 = llvm.insertvalue %2, %foo[0]
    //       : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    //   %4 = llvm.extractvalue %3[0, 0]
    //       : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    // ```
    // In the above example, %4 is folded to %arg1.
    if (extractPosSize > insertPosSize &&
        extractPos.take_front(insertPosSize) == insertPos) {
      insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
      extractPos = extractPos.drop_front(insertPosSize);
      switchedToInsertedValue = true;
      continue;
    }

    // Case 3: Try to continue the traversal with the container value.
    unsigned min = std::min(extractPosSize, insertPosSize);

    // If one is fully prefix of the other, stop propagating back as it will
    // miss dependencies. For instance, %3 should not fold to %f0 in the
    // following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (extractPos.take_front(min) == insertPos.take_front(min))
      return result;
    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    if (!switchedToInsertedValue) {
      // Do not swap out the container operand if we decided earlier to
      // continue the traversal with the inserted value (Case 2).
      getContainerMutable().assign(insertValueOp.getContainer());
      result = getResult();
    }
    // Keep walking through the container of the current insertvalue.
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
1950
1951LogicalResult ExtractValueOp::verify() {
1952 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1953 Type valueType = getInsertExtractValueElementType(
1954 emitError, getContainer().getType(), getPosition());
1955 if (!valueType)
1956 return failure();
1957
1958 if (getRes().getType() != valueType)
1959 return emitOpError() << "Type mismatch: extracting from "
1960 << getContainer().getType() << " should produce "
1961 << valueType << " but this op returns "
1962 << getRes().getType();
1963 return success();
1964}
1965
1966void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
1967 Value container, ArrayRef<int64_t> position) {
1968 build(builder, state,
1969 getInsertExtractValueElementType(container.getType(), position),
1970 container, builder.getAttr<DenseI64ArrayAttr>(position));
1971}
1972
1973//===----------------------------------------------------------------------===//
1974// InsertValueOp
1975//===----------------------------------------------------------------------===//
1976
1977/// Infer the value type from the container type and position.
1978static ParseResult
1979parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
1980 Type containerType,
1981 DenseI64ArrayAttr position) {
1982 valueType = getInsertExtractValueElementType(
1983 [&](StringRef msg) {
1984 return parser.emitError(loc: parser.getCurrentLocation(), message: msg);
1985 },
1986 containerType, position.asArrayRef());
1987 return success(IsSuccess: !!valueType);
1988}
1989
1990/// Nothing to print for an inferred type.
1991static void printInsertExtractValueElementType(AsmPrinter &printer,
1992 Operation *op, Type valueType,
1993 Type containerType,
1994 DenseI64ArrayAttr position) {}
1995
1996LogicalResult InsertValueOp::verify() {
1997 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1998 Type valueType = getInsertExtractValueElementType(
1999 emitError, getContainer().getType(), getPosition());
2000 if (!valueType)
2001 return failure();
2002
2003 if (getValue().getType() != valueType)
2004 return emitOpError() << "Type mismatch: cannot insert "
2005 << getValue().getType() << " into "
2006 << getContainer().getType();
2007
2008 return success();
2009}
2010
2011//===----------------------------------------------------------------------===//
2012// ReturnOp
2013//===----------------------------------------------------------------------===//
2014
2015LogicalResult ReturnOp::verify() {
2016 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
2017 if (!parent)
2018 return success();
2019
2020 Type expectedType = parent.getFunctionType().getReturnType();
2021 if (llvm::isa<LLVMVoidType>(expectedType)) {
2022 if (!getArg())
2023 return success();
2024 InFlightDiagnostic diag = emitOpError("expected no operands");
2025 diag.attachNote(parent->getLoc()) << "when returning from function";
2026 return diag;
2027 }
2028 if (!getArg()) {
2029 if (llvm::isa<LLVMVoidType>(expectedType))
2030 return success();
2031 InFlightDiagnostic diag = emitOpError("expected 1 operand");
2032 diag.attachNote(parent->getLoc()) << "when returning from function";
2033 return diag;
2034 }
2035 if (expectedType != getArg().getType()) {
2036 InFlightDiagnostic diag = emitOpError("mismatching result types");
2037 diag.attachNote(parent->getLoc()) << "when returning from function";
2038 return diag;
2039 }
2040 return success();
2041}
2042
2043//===----------------------------------------------------------------------===//
2044// LLVM::AddressOfOp.
2045//===----------------------------------------------------------------------===//
2046
2047static Operation *parentLLVMModule(Operation *op) {
2048 Operation *module = op->getParentOp();
2049 while (module && !satisfiesLLVMModule(op: module))
2050 module = module->getParentOp();
2051 assert(module && "unexpected operation outside of a module");
2052 return module;
2053}
2054
2055GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2056 return dyn_cast_or_null<GlobalOp>(
2057 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2058}
2059
2060LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2061 return dyn_cast_or_null<LLVMFuncOp>(
2062 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2063}
2064
2065AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
2066 return dyn_cast_or_null<AliasOp>(
2067 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2068}
2069
2070LogicalResult
2071AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2072 Operation *symbol =
2073 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
2074
2075 auto global = dyn_cast_or_null<GlobalOp>(symbol);
2076 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2077 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2078
2079 if (!global && !function && !alias)
2080 return emitOpError("must reference a global defined by 'llvm.mlir.global', "
2081 "'llvm.mlir.alias' or 'llvm.func'");
2082
2083 LLVMPointerType type = getType();
2084 if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
2085 (alias && alias.getAddrSpace() != type.getAddressSpace()))
2086 return emitOpError("pointer address space must match address space of the "
2087 "referenced global or alias");
2088
2089 return success();
2090}
2091
2092// AddressOfOp constant-folds to the global symbol name.
2093OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2094 return getGlobalNameAttr();
2095}
2096
2097//===----------------------------------------------------------------------===//
2098// LLVM::DSOLocalEquivalentOp
2099//===----------------------------------------------------------------------===//
2100
2101LLVMFuncOp
2102DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) {
2103 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
2104 parentLLVMModule(*this), getFunctionNameAttr()));
2105}
2106
2107AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) {
2108 return dyn_cast_or_null<AliasOp>(symbolTable.lookupSymbolIn(
2109 parentLLVMModule(*this), getFunctionNameAttr()));
2110}
2111
2112LogicalResult
2113DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2114 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
2115 getFunctionNameAttr());
2116 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2117 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2118
2119 if (!function && !alias)
2120 return emitOpError(
2121 "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'");
2122
2123 if (alias) {
2124 if (alias.getInitializer()
2125 .walk([&](AddressOfOp addrOp) {
2126 if (addrOp.getGlobal(symbolTable))
2127 return WalkResult::interrupt();
2128 return WalkResult::advance();
2129 })
2130 .wasInterrupted())
2131 return emitOpError("must reference an alias to a function");
2132 }
2133
2134 if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) ||
2135 (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak))
2136 return emitOpError(
2137 "target function with 'extern_weak' linkage not allowed");
2138
2139 return success();
2140}
2141
2142/// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent
2143/// attribute.
2144OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) {
2145 return DSOLocalEquivalentAttr::get(getContext(), getFunctionNameAttr());
2146}
2147
2148//===----------------------------------------------------------------------===//
// Builder and verifier for LLVM::ComdatOp.
2150//===----------------------------------------------------------------------===//
2151
2152void ComdatOp::build(OpBuilder &builder, OperationState &result,
2153 StringRef symName) {
2154 result.addAttribute(getSymNameAttrName(result.name),
2155 builder.getStringAttr(symName));
2156 Region *body = result.addRegion();
2157 body->emplaceBlock();
2158}
2159
2160LogicalResult ComdatOp::verifyRegions() {
2161 Region &body = getBody();
2162 for (Operation &op : body.getOps())
2163 if (!isa<ComdatSelectorOp>(op))
2164 return op.emitError(
2165 "only comdat selector symbols can appear in a comdat region");
2166
2167 return success();
2168}
2169
2170//===----------------------------------------------------------------------===//
2171// Builder, printer and verifier for LLVM::GlobalOp.
2172//===----------------------------------------------------------------------===//
2173
2174void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
2175 bool isConstant, Linkage linkage, StringRef name,
2176 Attribute value, uint64_t alignment, unsigned addrSpace,
2177 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
2178 ArrayRef<NamedAttribute> attrs,
2179 ArrayRef<Attribute> dbgExprs) {
2180 result.addAttribute(getSymNameAttrName(result.name),
2181 builder.getStringAttr(name));
2182 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
2183 if (isConstant)
2184 result.addAttribute(getConstantAttrName(result.name),
2185 builder.getUnitAttr());
2186 if (value)
2187 result.addAttribute(getValueAttrName(result.name), value);
2188 if (dsoLocal)
2189 result.addAttribute(getDsoLocalAttrName(result.name),
2190 builder.getUnitAttr());
2191 if (threadLocal)
2192 result.addAttribute(getThreadLocal_AttrName(result.name),
2193 builder.getUnitAttr());
2194 if (comdat)
2195 result.addAttribute(getComdatAttrName(result.name), comdat);
2196
2197 // Only add an alignment attribute if the "alignment" input
2198 // is different from 0. The value must also be a power of two, but
2199 // this is tested in GlobalOp::verify, not here.
2200 if (alignment != 0)
2201 result.addAttribute(getAlignmentAttrName(result.name),
2202 builder.getI64IntegerAttr(alignment));
2203
2204 result.addAttribute(getLinkageAttrName(result.name),
2205 LinkageAttr::get(builder.getContext(), linkage));
2206 if (addrSpace != 0)
2207 result.addAttribute(getAddrSpaceAttrName(result.name),
2208 builder.getI32IntegerAttr(addrSpace));
2209 result.attributes.append(attrs.begin(), attrs.end());
2210
2211 if (!dbgExprs.empty())
2212 result.addAttribute(getDbgExprsAttrName(result.name),
2213 ArrayAttr::get(builder.getContext(), dbgExprs));
2214
2215 result.addRegion();
2216}
2217
2218template <typename OpType>
2219static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op) {
2220 p << ' ' << stringifyLinkage(op.getLinkage()) << ' ';
2221 StringRef visibility = stringifyVisibility(op.getVisibility_());
2222 if (!visibility.empty())
2223 p << visibility << ' ';
2224 if (op.getThreadLocal_())
2225 p << "thread_local ";
2226 if (auto unnamedAddr = op.getUnnamedAddr()) {
2227 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2228 if (!str.empty())
2229 p << str << ' ';
2230 }
2231}
2232
2233void GlobalOp::print(OpAsmPrinter &p) {
2234 printCommonGlobalAndAlias<GlobalOp>(p, *this);
2235 if (getConstant())
2236 p << "constant ";
2237 p.printSymbolName(getSymName());
2238 p << '(';
2239 if (auto value = getValueOrNull())
2240 p.printAttribute(value);
2241 p << ')';
2242 if (auto comdat = getComdat())
2243 p << " comdat(" << *comdat << ')';
2244
2245 // Note that the alignment attribute is printed using the
2246 // default syntax here, even though it is an inherent attribute
2247 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
2248 p.printOptionalAttrDict((*this)->getAttrs(),
2249 {SymbolTable::getSymbolAttrName(),
2250 getGlobalTypeAttrName(), getConstantAttrName(),
2251 getValueAttrName(), getLinkageAttrName(),
2252 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2253 getVisibility_AttrName(), getComdatAttrName(),
2254 getUnnamedAddrAttrName()});
2255
2256 // Print the trailing type unless it's a string global.
2257 if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
2258 return;
2259 p << " : " << getType();
2260
2261 Region &initializer = getInitializerRegion();
2262 if (!initializer.empty()) {
2263 p << ' ';
2264 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
2265 }
2266}
2267
2268static LogicalResult verifyComdat(Operation *op,
2269 std::optional<SymbolRefAttr> attr) {
2270 if (!attr)
2271 return success();
2272
2273 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
2274 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
2275 return op->emitError() << "expected comdat symbol";
2276
2277 return success();
2278}
2279
2280static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
2281 llvm::DenseSet<BlockTagAttr> blockTags;
2282 // Note that presence of `BlockTagOp`s currently can't prevent an unrecheable
2283 // block to be removed by canonicalizer's region simplify pass, which needs to
2284 // be dialect aware to allow extra constraints to be described.
2285 WalkResult res = funcOp.walk([&](BlockTagOp blockTagOp) {
2286 if (blockTags.contains(blockTagOp.getTag())) {
2287 blockTagOp.emitError()
2288 << "duplicate block tag '" << blockTagOp.getTag().getId()
2289 << "' in the same function: ";
2290 return WalkResult::interrupt();
2291 }
2292 blockTags.insert(blockTagOp.getTag());
2293 return WalkResult::advance();
2294 });
2295
2296 return failure(IsFailure: res.wasInterrupted());
2297}
2298
2299/// Parse common attributes that might show up in the same order in both
2300/// GlobalOp and AliasOp.
2301template <typename OpType>
2302static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser,
2303 OperationState &result) {
2304 MLIRContext *ctx = parser.getContext();
2305 // Parse optional linkage, default to External.
2306 result.addAttribute(OpType::getLinkageAttrName(result.name),
2307 LLVM::LinkageAttr::get(
2308 ctx, parseOptionalLLVMKeyword<Linkage>(
2309 parser, result, LLVM::Linkage::External)));
2310
2311 // Parse optional visibility, default to Default.
2312 result.addAttribute(OpType::getVisibility_AttrName(result.name),
2313 parser.getBuilder().getI64IntegerAttr(
2314 parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2315 parser, result, LLVM::Visibility::Default)));
2316
2317 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "thread_local")))
2318 result.addAttribute(OpType::getThreadLocal_AttrName(result.name),
2319 parser.getBuilder().getUnitAttr());
2320
2321 // Parse optional UnnamedAddr, default to None.
2322 result.addAttribute(OpType::getUnnamedAddrAttrName(result.name),
2323 parser.getBuilder().getI64IntegerAttr(
2324 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2325 parser, result, LLVM::UnnamedAddr::None)));
2326
2327 return success();
2328}
2329
// operation ::= `llvm.mlir.global` linkage? visibility?
//               `thread_local`? (`unnamed_addr` | `local_unnamed_addr`)?
//               `constant`? `@` identifier
2333// `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
2334// attribute-list? (`:` type)? region?
2335//
2336// The type can be omitted for string attributes, in which case it will be
2337// inferred from the value of the string as [strlen(value) x i8].
2338ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
2339 // Call into common parsing between GlobalOp and AliasOp.
2340 if (parseCommonGlobalAndAlias<GlobalOp>(parser, result).failed())
2341 return failure();
2342
2343 if (succeeded(parser.parseOptionalKeyword("constant")))
2344 result.addAttribute(getConstantAttrName(result.name),
2345 parser.getBuilder().getUnitAttr());
2346
2347 StringAttr name;
2348 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2349 result.attributes) ||
2350 parser.parseLParen())
2351 return failure();
2352
2353 Attribute value;
2354 if (parser.parseOptionalRParen()) {
2355 if (parser.parseAttribute(value, getValueAttrName(result.name),
2356 result.attributes) ||
2357 parser.parseRParen())
2358 return failure();
2359 }
2360
2361 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2362 SymbolRefAttr comdat;
2363 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2364 parser.parseRParen())
2365 return failure();
2366
2367 result.addAttribute(getComdatAttrName(result.name), comdat);
2368 }
2369
2370 SmallVector<Type, 1> types;
2371 if (parser.parseOptionalAttrDict(result.attributes) ||
2372 parser.parseOptionalColonTypeList(types))
2373 return failure();
2374
2375 if (types.size() > 1)
2376 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2377
2378 Region &initRegion = *result.addRegion();
2379 if (types.empty()) {
2380 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
2381 MLIRContext *context = parser.getContext();
2382 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
2383 strAttr.getValue().size());
2384 types.push_back(arrayType);
2385 } else {
2386 return parser.emitError(parser.getNameLoc(),
2387 "type can only be omitted for string globals");
2388 }
2389 } else {
2390 OptionalParseResult parseResult =
2391 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
2392 /*argTypes=*/{});
2393 if (parseResult.has_value() && failed(*parseResult))
2394 return failure();
2395 }
2396
2397 result.addAttribute(getGlobalTypeAttrName(result.name),
2398 TypeAttr::get(types[0]));
2399 return success();
2400}
2401
2402static bool isZeroAttribute(Attribute value) {
2403 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2404 return intValue.getValue().isZero();
2405 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2406 return fpValue.getValue().isZero();
2407 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(Val&: value))
2408 return isZeroAttribute(value: splatValue.getSplatValue<Attribute>());
2409 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2410 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2411 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2412 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2413 return false;
2414}
2415
2416LogicalResult GlobalOp::verify() {
2417 bool validType = isCompatibleOuterType(getType())
2418 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2419 LLVMMetadataType, LLVMLabelType>(getType())
2420 : llvm::isa<PointerElementTypeInterface>(getType());
2421 if (!validType)
2422 return emitOpError(
2423 "expects type to be a valid element type for an LLVM global");
2424 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2425 return emitOpError("must appear at the module level");
2426
2427 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2428 auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2429 IntegerType elementType =
2430 type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2431 if (!elementType || elementType.getWidth() != 8 ||
2432 type.getNumElements() != strAttr.getValue().size())
2433 return emitOpError(
2434 "requires an i8 array type of the length equal to that of the string "
2435 "attribute");
2436 }
2437
2438 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2439 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2440 return emitOpError()
2441 << "this target extension type cannot be used in a global";
2442
2443 if (Attribute value = getValueOrNull())
2444 return emitOpError() << "global with target extension type can only be "
2445 "initialized with zero-initializer";
2446 }
2447
2448 if (getLinkage() == Linkage::Common) {
2449 if (Attribute value = getValueOrNull()) {
2450 if (!isZeroAttribute(value)) {
2451 return emitOpError()
2452 << "expected zero value for '"
2453 << stringifyLinkage(Linkage::Common) << "' linkage";
2454 }
2455 }
2456 }
2457
2458 if (getLinkage() == Linkage::Appending) {
2459 if (!llvm::isa<LLVMArrayType>(getType())) {
2460 return emitOpError() << "expected array type for '"
2461 << stringifyLinkage(Linkage::Appending)
2462 << "' linkage";
2463 }
2464 }
2465
2466 if (failed(verifyComdat(*this, getComdat())))
2467 return failure();
2468
2469 std::optional<uint64_t> alignAttr = getAlignment();
2470 if (alignAttr.has_value()) {
2471 uint64_t value = alignAttr.value();
2472 if (!llvm::isPowerOf2_64(value))
2473 return emitError() << "alignment attribute is not a power of 2";
2474 }
2475
2476 return success();
2477}
2478
2479LogicalResult GlobalOp::verifyRegions() {
2480 if (Block *b = getInitializerBlock()) {
2481 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2482 if (ret.operand_type_begin() == ret.operand_type_end())
2483 return emitOpError("initializer region cannot return void");
2484 if (*ret.operand_type_begin() != getType())
2485 return emitOpError("initializer region type ")
2486 << *ret.operand_type_begin() << " does not match global type "
2487 << getType();
2488
2489 for (Operation &op : *b) {
2490 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2491 if (!iface || !iface.hasNoEffect())
2492 return op.emitError()
2493 << "ops with side effects not allowed in global initializers";
2494 }
2495
2496 if (getValueOrNull())
2497 return emitOpError("cannot have both initializer value and region");
2498 }
2499
2500 return success();
2501}
2502
2503//===----------------------------------------------------------------------===//
2504// LLVM::GlobalCtorsOp
2505//===----------------------------------------------------------------------===//
2506
2507LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
2508 if (data.empty())
2509 return success();
2510
2511 if (llvm::all_of(data.getAsRange<Attribute>(), [](Attribute v) {
2512 return isa<FlatSymbolRefAttr, ZeroAttr>(v);
2513 }))
2514 return success();
2515 return op->emitError(message: "data element must be symbol or #llvm.zero");
2516}
2517
2518LogicalResult
2519GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2520 for (Attribute ctor : getCtors()) {
2521 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2522 symbolTable)))
2523 return failure();
2524 }
2525 return success();
2526}
2527
2528LogicalResult GlobalCtorsOp::verify() {
2529 if (checkGlobalXtorData(*this, getData()).failed())
2530 return failure();
2531
2532 if (getCtors().size() == getPriorities().size() &&
2533 getCtors().size() == getData().size())
2534 return success();
2535 return emitError(
2536 "ctors, priorities, and data must have the same number of elements");
2537}
2538
2539//===----------------------------------------------------------------------===//
2540// LLVM::GlobalDtorsOp
2541//===----------------------------------------------------------------------===//
2542
2543LogicalResult
2544GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2545 for (Attribute dtor : getDtors()) {
2546 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2547 symbolTable)))
2548 return failure();
2549 }
2550 return success();
2551}
2552
2553LogicalResult GlobalDtorsOp::verify() {
2554 if (checkGlobalXtorData(*this, getData()).failed())
2555 return failure();
2556
2557 if (getDtors().size() == getPriorities().size() &&
2558 getDtors().size() == getData().size())
2559 return success();
2560 return emitError(
2561 "dtors, priorities, and data must have the same number of elements");
2562}
2563
2564//===----------------------------------------------------------------------===//
2565// Builder, printer and verifier for LLVM::AliasOp.
2566//===----------------------------------------------------------------------===//
2567
2568void AliasOp::build(OpBuilder &builder, OperationState &result, Type type,
2569 Linkage linkage, StringRef name, bool dsoLocal,
2570 bool threadLocal, ArrayRef<NamedAttribute> attrs) {
2571 result.addAttribute(getSymNameAttrName(result.name),
2572 builder.getStringAttr(name));
2573 result.addAttribute(getAliasTypeAttrName(result.name), TypeAttr::get(type));
2574 if (dsoLocal)
2575 result.addAttribute(getDsoLocalAttrName(result.name),
2576 builder.getUnitAttr());
2577 if (threadLocal)
2578 result.addAttribute(getThreadLocal_AttrName(result.name),
2579 builder.getUnitAttr());
2580
2581 result.addAttribute(getLinkageAttrName(result.name),
2582 LinkageAttr::get(builder.getContext(), linkage));
2583 result.attributes.append(attrs.begin(), attrs.end());
2584
2585 result.addRegion();
2586}
2587
2588void AliasOp::print(OpAsmPrinter &p) {
2589 printCommonGlobalAndAlias<AliasOp>(p, *this);
2590
2591 p.printSymbolName(getSymName());
2592 p.printOptionalAttrDict((*this)->getAttrs(),
2593 {SymbolTable::getSymbolAttrName(),
2594 getAliasTypeAttrName(), getLinkageAttrName(),
2595 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2596 getVisibility_AttrName(), getUnnamedAddrAttrName()});
2597
2598 // Print the trailing type.
2599 p << " : " << getType() << ' ';
2600 // Print the initializer region.
2601 p.printRegion(getInitializerRegion(), /*printEntryBlockArgs=*/false);
2602}
2603
2604// operation ::= `llvm.mlir.alias` linkage? visibility?
2605// (`unnamed_addr` | `local_unnamed_addr`)?
2606// `thread_local`? `@` identifier
2607// `(` attribute? `)`
2608// attribute-list? `:` type region
2609//
2610ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
2611 // Call into common parsing between GlobalOp and AliasOp.
2612 if (parseCommonGlobalAndAlias<AliasOp>(parser, result).failed())
2613 return failure();
2614
2615 StringAttr name;
2616 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2617 result.attributes))
2618 return failure();
2619
2620 SmallVector<Type, 1> types;
2621 if (parser.parseOptionalAttrDict(result.attributes) ||
2622 parser.parseOptionalColonTypeList(types))
2623 return failure();
2624
2625 if (types.size() > 1)
2626 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2627
2628 Region &initRegion = *result.addRegion();
2629 if (parser.parseRegion(initRegion).failed())
2630 return failure();
2631
2632 result.addAttribute(getAliasTypeAttrName(result.name),
2633 TypeAttr::get(types[0]));
2634 return success();
2635}
2636
2637LogicalResult AliasOp::verify() {
2638 bool validType = isCompatibleOuterType(getType())
2639 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2640 LLVMMetadataType, LLVMLabelType>(getType())
2641 : llvm::isa<PointerElementTypeInterface>(getType());
2642 if (!validType)
2643 return emitOpError(
2644 "expects type to be a valid element type for an LLVM global alias");
2645
2646 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2647 switch (getLinkage()) {
2648 case Linkage::External:
2649 case Linkage::Internal:
2650 case Linkage::Private:
2651 case Linkage::Weak:
2652 case Linkage::WeakODR:
2653 case Linkage::Linkonce:
2654 case Linkage::LinkonceODR:
2655 case Linkage::AvailableExternally:
2656 break;
2657 default:
2658 return emitOpError()
2659 << "'" << stringifyLinkage(getLinkage())
2660 << "' linkage not supported in aliases, available options: private, "
2661 "internal, linkonce, weak, linkonce_odr, weak_odr, external or "
2662 "available_externally";
2663 }
2664
2665 return success();
2666}
2667
2668LogicalResult AliasOp::verifyRegions() {
2669 Block &b = getInitializerBlock();
2670 auto ret = cast<ReturnOp>(b.getTerminator());
2671 if (ret.getNumOperands() == 0 ||
2672 !isa<LLVM::LLVMPointerType>(ret.getOperand(0).getType()))
2673 return emitOpError("initializer region must always return a pointer");
2674
2675 for (Operation &op : b) {
2676 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2677 if (!iface || !iface.hasNoEffect())
2678 return op.emitError()
2679 << "ops with side effects are not allowed in alias initializers";
2680 }
2681
2682 return success();
2683}
2684
2685unsigned AliasOp::getAddrSpace() {
2686 Block &initializer = getInitializerBlock();
2687 auto ret = cast<ReturnOp>(initializer.getTerminator());
2688 auto ptrTy = cast<LLVMPointerType>(ret.getOperand(0).getType());
2689 return ptrTy.getAddressSpace();
2690}
2691
2692//===----------------------------------------------------------------------===//
2693// ShuffleVectorOp
2694//===----------------------------------------------------------------------===//
2695
2696void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2697 Value v2, DenseI32ArrayAttr mask,
2698 ArrayRef<NamedAttribute> attrs) {
2699 auto containerType = v1.getType();
2700 auto vType = LLVM::getVectorType(
2701 cast<VectorType>(containerType).getElementType(), mask.size(),
2702 LLVM::isScalableVectorType(containerType));
2703 build(builder, state, vType, v1, v2, mask);
2704 state.addAttributes(attrs);
2705}
2706
2707void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2708 Value v2, ArrayRef<int32_t> mask) {
2709 build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2710}
2711
2712/// Build the result type of a shuffle vector operation.
2713static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2714 Type &resType, DenseI32ArrayAttr mask) {
2715 if (!LLVM::isCompatibleVectorType(type: v1Type))
2716 return parser.emitError(loc: parser.getCurrentLocation(),
2717 message: "expected an LLVM compatible vector type");
2718 resType =
2719 LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
2720 mask.size(), LLVM::isScalableVectorType(vectorType: v1Type));
2721 return success();
2722}
2723
2724/// Nothing to do when the result type is inferred.
2725static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2726 Type resType, DenseI32ArrayAttr mask) {}
2727
2728LogicalResult ShuffleVectorOp::verify() {
2729 if (LLVM::isScalableVectorType(getV1().getType()) &&
2730 llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2731 return emitOpError("expected a splat operation for scalable vectors");
2732 return success();
2733}
2734
2735//===----------------------------------------------------------------------===//
2736// Implementations for LLVM::LLVMFuncOp.
2737//===----------------------------------------------------------------------===//
2738
2739// Add the entry block to the function.
2740Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
2741 assert(empty() && "function already has an entry block");
2742 OpBuilder::InsertionGuard g(builder);
2743 Block *entry = builder.createBlock(&getBody());
2744
2745 // FIXME: Allow passing in proper locations for the entry arguments.
2746 LLVMFunctionType type = getFunctionType();
2747 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2748 entry->addArgument(type.getParamType(i), getLoc());
2749 return entry;
2750}
2751
2752void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2753 StringRef name, Type type, LLVM::Linkage linkage,
2754 bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
2755 ArrayRef<NamedAttribute> attrs,
2756 ArrayRef<DictionaryAttr> argAttrs,
2757 std::optional<uint64_t> functionEntryCount) {
2758 result.addRegion();
2759 result.addAttribute(SymbolTable::getSymbolAttrName(),
2760 builder.getStringAttr(name));
2761 result.addAttribute(getFunctionTypeAttrName(result.name),
2762 TypeAttr::get(type));
2763 result.addAttribute(getLinkageAttrName(result.name),
2764 LinkageAttr::get(builder.getContext(), linkage));
2765 result.addAttribute(getCConvAttrName(result.name),
2766 CConvAttr::get(builder.getContext(), cconv));
2767 result.attributes.append(attrs.begin(), attrs.end());
2768 if (dsoLocal)
2769 result.addAttribute(getDsoLocalAttrName(result.name),
2770 builder.getUnitAttr());
2771 if (comdat)
2772 result.addAttribute(getComdatAttrName(result.name), comdat);
2773 if (functionEntryCount)
2774 result.addAttribute(getFunctionEntryCountAttrName(result.name),
2775 builder.getI64IntegerAttr(functionEntryCount.value()));
2776 if (argAttrs.empty())
2777 return;
2778
2779 assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
2780 "expected as many argument attribute lists as arguments");
2781 call_interface_impl::addArgAndResultAttrs(
2782 builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
2783 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2784}
2785
2786// Builds an LLVM function type from the given lists of input and output types.
2787// Returns a null type if any of the types provided are non-LLVM types, or if
2788// there is more than one output type.
2789static Type
2790buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2791 ArrayRef<Type> outputs,
2792 function_interface_impl::VariadicFlag variadicFlag) {
2793 Builder &b = parser.getBuilder();
2794 if (outputs.size() > 1) {
2795 parser.emitError(loc, message: "failed to construct function type: expected zero or "
2796 "one function result");
2797 return {};
2798 }
2799
2800 // Convert inputs to LLVM types, exit early on error.
2801 SmallVector<Type, 4> llvmInputs;
2802 for (auto t : inputs) {
2803 if (!isCompatibleType(type: t)) {
2804 parser.emitError(loc, message: "failed to construct function type: expected LLVM "
2805 "type for function arguments");
2806 return {};
2807 }
2808 llvmInputs.push_back(Elt: t);
2809 }
2810
2811 // No output is denoted as "void" in LLVM type system.
2812 Type llvmOutput =
2813 outputs.empty() ? LLVMVoidType::get(ctx: b.getContext()) : outputs.front();
2814 if (!isCompatibleType(type: llvmOutput)) {
2815 parser.emitError(loc, message: "failed to construct function type: expected LLVM "
2816 "type for function results")
2817 << llvmOutput;
2818 return {};
2819 }
2820 return LLVMFunctionType::get(llvmOutput, llvmInputs,
2821 variadicFlag.isVariadic());
2822}
2823
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               function-attributes?
//               function-body?
//
// Omitted keywords fall back to defaults: `external` linkage, `default`
// visibility, no unnamed_addr, and the C calling convention. All four are
// always stored as attributes. The body is optional; when absent the body
// region stays empty.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      getLinkageAttrName(result.name),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  // Only read after parseFunctionSignatureWithArguments succeeds, which sets it.
  bool isVariadic;

  // Parse `@name` and the signature. Arguments are named entry-block arguments
  // that may carry attribute dictionaries; a trailing `...` marks a variadic
  // function.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Build the LLVM function type. The helper emits the diagnostic (at the
  // signature location) for non-LLVM types or multiple results.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Optional `vscale_range(min, max)`. Both bounds are parsed as int64_t and
  // stored as i32 IntegerAttrs; no range check is done on them here.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  // Parse the optional `attributes {...}` dictionary. Then attach the per-argument
  // and per-result attribute dictionaries collected from the signature.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The body is optional. Only a region that is present but fails to parse is
  // an error.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
2916
// Prints the LLVMFuncOp in the custom form read by LLVMFuncOp::parse.
//
// Keywords equal to their parse-time defaults are omitted:
//  - external linkage;
//  - the C calling convention;
//  - visibility / unnamed_addr values whose spelling is empty.
// The "void" result is dropped since it cannot be parsed back. Attributes
// rendered by the custom syntax are elided from the trailing attribute
// dictionary.
void LLVMFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  if (getLinkage() != LLVM::Linkage::External)
    p << stringifyLinkage(getLinkage()) << ' ';
  StringRef visibility = stringifyVisibility(getVisibility_());
  if (!visibility.empty())
    p << visibility << ' ';
  if (auto unnamedAddr = getUnnamedAddr()) {
    StringRef str = stringifyUnnamedAddr(*unnamedAddr);
    if (!str.empty())
      p << str << ' ';
  }
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  p.printSymbolName(getName());

  // Collect the parameter types and the (non-void) result type for the
  // generic function-signature printer.
  LLVMFunctionType fnType = getFunctionType();
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 1> resTypes;
  argTypes.reserve(fnType.getNumParams());
  for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
    argTypes.push_back(fnType.getParamType(i));

  Type returnType = fnType.getReturnType();
  if (!llvm::isa<LLVMVoidType>(returnType))
    resTypes.push_back(returnType);

  function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                  isVarArg(), resTypes);

  // Print vscale range if present
  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
      << vscale->getMaxRange().getInt() << ')';

  // Print the optional comdat selector.
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // Every attribute listed here is already part of the custom syntax above.
  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
       getComdatAttrName(), getUnnamedAddrAttrName(),
       getVscaleRangeAttrName()});

  // Print the body if this is not an external function.
  Region &body = getBody();
  if (!body.empty()) {
    p << ' ';
    p.printRegion(body, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/true);
  }
}
2975
// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
// - functions don't have 'common' linkage;
// - the comdat reference, if any, passes verifyComdat;
// - external functions have 'external' or 'extern_weak' linkage, and none of
//   the remaining checks apply to them;
// - no_inline and always_inline are not both set;
// - optimize_none implies no_inline;
// - all 'llvm.landingpad' results and 'llvm.resume' operands in the body share
//   a single type;
// - block tags pass verifyBlockTags.
LogicalResult LLVMFuncOp::verify() {
  if (getLinkage() == LLVM::Linkage::Common)
    return emitOpError() << "functions cannot have '"
                         << stringifyLinkage(LLVM::Linkage::Common)
                         << "' linkage";

  if (failed(verifyComdat(*this, getComdat())))
    return failure();

  if (isExternal()) {
    if (getLinkage() != LLVM::Linkage::External &&
        getLinkage() != LLVM::Linkage::ExternWeak)
      return emitOpError() << "external functions must have '"
                           << stringifyLinkage(LLVM::Linkage::External)
                           << "' or '"
                           << stringifyLinkage(LLVM::Linkage::ExternWeak)
                           << "' linkage";
    return success();
  }

  // In LLVM IR, these attributes are composed by convention, not by design.
  if (isNoInline() && isAlwaysInline())
    return emitError("no_inline and always_inline attributes are incompatible");

  if (isOptimizeNone() && !isNoInline())
    return emitOpError("with optimize_none must also be no_inline");

  // Walk the function and record the first landingpad result / resume operand
  // type encountered. A later mismatch interrupts the walk and remembers which
  // diagnostic to report. The message literals are constexpr StringLiterals,
  // so the StringRef stays valid after the walk returns.
  Type landingpadResultTy;
  StringRef diagnosticMessage;
  bool isLandingpadTypeConsistent =
      !walk([&](Operation *op) {
        const auto checkType = [&](Type type, StringRef errorMessage) {
          if (!landingpadResultTy) {
            landingpadResultTy = type;
            return WalkResult::advance();
          }
          if (landingpadResultTy != type) {
            diagnosticMessage = errorMessage;
            return WalkResult::interrupt();
          }
          return WalkResult::advance();
        };
        return TypeSwitch<Operation *, WalkResult>(op)
            .Case<LandingpadOp>([&](auto landingpad) {
              constexpr StringLiteral errorMessage =
                  "'llvm.landingpad' should have a consistent result type "
                  "inside a function";
              return checkType(landingpad.getType(), errorMessage);
            })
            .Case<ResumeOp>([&](auto resume) {
              constexpr StringLiteral errorMessage =
                  "'llvm.resume' should have a consistent input type inside a "
                  "function";
              return checkType(resume.getValue().getType(), errorMessage);
            })
            // Other ops do not take part in the check.
            // NOTE(review): if this walk is post-order (MLIR's default),
            // skip() behaves like advance() here — confirm intent.
            .Default([](auto) { return WalkResult::skip(); });
      }).wasInterrupted();
  if (!isLandingpadTypeConsistent) {
    assert(!diagnosticMessage.empty() &&
           "Expecting a non-empty diagnostic message");
    return emitError(diagnosticMessage);
  }

  if (failed(verifyBlockTags(*this)))
    return failure();

  return success();
}
3048
3049/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
3050/// - entry block arguments are of LLVM types.
3051LogicalResult LLVMFuncOp::verifyRegions() {
3052 if (isExternal())
3053 return success();
3054
3055 unsigned numArguments = getFunctionType().getNumParams();
3056 Block &entryBlock = front();
3057 for (unsigned i = 0; i < numArguments; ++i) {
3058 Type argType = entryBlock.getArgument(i).getType();
3059 if (!isCompatibleType(argType))
3060 return emitOpError("entry block argument #")
3061 << i << " is not of LLVM type";
3062 }
3063
3064 return success();
3065}
3066
3067Region *LLVMFuncOp::getCallableRegion() {
3068 if (isExternal())
3069 return nullptr;
3070 return &getBody();
3071}
3072
3073//===----------------------------------------------------------------------===//
3074// UndefOp.
3075//===----------------------------------------------------------------------===//
3076
3077/// Fold an undef operation to a dedicated undef attribute.
3078OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
3079 return LLVM::UndefAttr::get(getContext());
3080}
3081
3082//===----------------------------------------------------------------------===//
3083// PoisonOp.
3084//===----------------------------------------------------------------------===//
3085
3086/// Fold a poison operation to a dedicated poison attribute.
3087OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
3088 return LLVM::PoisonAttr::get(getContext());
3089}
3090
3091//===----------------------------------------------------------------------===//
3092// ZeroOp.
3093//===----------------------------------------------------------------------===//
3094
3095LogicalResult LLVM::ZeroOp::verify() {
3096 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3097 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
3098 return emitOpError()
3099 << "target extension type does not support zero-initializer";
3100
3101 return success();
3102}
3103
3104/// Fold a zero operation to a builtin zero attribute when possible and fall
3105/// back to a dedicated zero attribute.
3106OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
3107 OpFoldResult result = Builder(getContext()).getZeroAttr(getType());
3108 if (result)
3109 return result;
3110 return LLVM::ZeroAttr::get(getContext());
3111}
3112
3113//===----------------------------------------------------------------------===//
3114// ConstantOp.
3115//===----------------------------------------------------------------------===//
3116
3117/// Compute the total number of elements in the given type, also taking into
3118/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
3119/// Everything else is treated as a scalar.
3120static int64_t getNumElements(Type t) {
3121 if (auto vecType = dyn_cast<VectorType>(t)) {
3122 assert(!vecType.isScalable() &&
3123 "number of elements of a scalable vector type is unknown");
3124 return vecType.getNumElements() * getNumElements(vecType.getElementType());
3125 }
3126 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3127 return arrayType.getNumElements() *
3128 getNumElements(arrayType.getElementType());
3129 return 1;
3130}
3131
3132/// Check if the given type is a scalable vector type or a vector/array type
3133/// that contains a nested scalable vector type.
3134static bool hasScalableVectorType(Type t) {
3135 if (auto vecType = dyn_cast<VectorType>(t)) {
3136 if (vecType.isScalable())
3137 return true;
3138 return hasScalableVectorType(vecType.getElementType());
3139 }
3140 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3141 return hasScalableVectorType(arrayType.getElementType());
3142 return false;
3143}
3144
/// Verifies the constant array represented by `arrayAttr` matches the provided
/// `arrayType`. `dim` is the nesting depth of `arrayType` (0 for the outermost
/// array) and is only used in diagnostics.
///
/// Multi-dimensional arrays are checked recursively, one nesting level per
/// call. At the innermost level the array element type must be a struct. Each
/// element attribute must then be `#llvm.zero`, `#llvm.undef`, or an ArrayAttr
/// with one entry per struct field. Only the entry count of those ArrayAttrs is
/// checked here, not the field values.
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
                                               LLVM::LLVMArrayType arrayType,
                                               ArrayAttr arrayAttr, int dim) {
  if (arrayType.getNumElements() != arrayAttr.size())
    return op.emitOpError()
           << "array attribute size does not match array type size in "
              "dimension "
           << dim << ": " << arrayAttr.size() << " vs. "
           << arrayType.getNumElements();

  // Identical element attributes verify identically, so each distinct
  // attribute is checked only once. Diagnostics therefore report the index of
  // its first occurrence.
  llvm::DenseSet<Attribute> elementsVerified;

  // Recursively verify sub-dimensions for multidimensional arrays.
  if (auto subArrayType =
          dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) {
    for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr))
      if (elementsVerified.insert(elementAttr).second) {
        // Zero/undef initialize the whole sub-array; nothing to recurse into.
        if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
          continue;
        auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
        if (!subArrayAttr)
          return op.emitOpError()
                 << "nested attribute for sub-array in dimension " << dim
                 << " at index " << idx
                 << " must be a zero, or undef, or array attribute";
        if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr,
                                             dim + 1)))
          return failure();
      }
    return success();
  }

  // Forbid usages of ArrayAttr for simple array types that should use
  // DenseElementsAttr instead. Note that there would be a use case for such
  // array types when one element value is obtained via a ptr-to-int conversion
  // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
  // user needs this so far, and it seems better to avoid people misusing the
  // ArrayAttr for simple types.
  auto structType = dyn_cast<LLVM::LLVMStructType>(arrayType.getElementType());
  if (!structType)
    return op.emitOpError() << "for array with an array attribute must have a "
                               "struct element type";

  // Shallow verification that leaf attributes are appropriate as struct initial
  // value.
  size_t numStructElements = structType.getBody().size();
  for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
    if (elementsVerified.insert(elementAttr).second) {
      if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
        continue;
      auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
      if (!subArrayAttr)
        return op.emitOpError()
               << "nested attribute for struct element at index " << idx
               << " must be a zero, or undef, or array attribute";
      if (subArrayAttr.size() != numStructElements)
        return op.emitOpError()
               << "nested array attribute size for struct element at index "
               << idx << " must match struct size: " << subArrayAttr.size()
               << " vs. " << numStructElements;
    }
  }

  return success();
}
3212
3213LogicalResult LLVM::ConstantOp::verify() {
3214 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
3215 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
3216 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
3217 !arrayType.getElementType().isInteger(8)) {
3218 return emitOpError() << "expected array type of "
3219 << sAttr.getValue().size()
3220 << " i8 elements for the string constant";
3221 }
3222 return success();
3223 }
3224 if (auto structType = dyn_cast<LLVMStructType>(getType())) {
3225 auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
3226 if (!arrayAttr) {
3227 return emitOpError() << "expected array attribute for a struct constant";
3228 }
3229
3230 ArrayRef<Type> elementTypes = structType.getBody();
3231 if (arrayAttr.size() != elementTypes.size()) {
3232 return emitOpError() << "expected array attribute of size "
3233 << elementTypes.size();
3234 }
3235 for (auto elementTy : elementTypes) {
3236 if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
3237 return emitOpError() << "expected struct element types to be floating "
3238 "point type or integer type";
3239 }
3240 }
3241
3242 for (size_t i = 0; i < elementTypes.size(); ++i) {
3243 Attribute element = arrayAttr[i];
3244 if (!isa<IntegerAttr, FloatAttr>(element)) {
3245 return emitOpError()
3246 << "expected struct element attribute types to be floating "
3247 "point type or integer type";
3248 }
3249 auto elementType = cast<TypedAttr>(element).getType();
3250 if (elementType != elementTypes[i]) {
3251 return emitOpError()
3252 << "struct element at index " << i << " is of wrong type";
3253 }
3254 }
3255
3256 return success();
3257 }
3258 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
3259 return emitOpError() << "does not support target extension type.";
3260 }
3261
3262 // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3263 if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
3264 if (!llvm::isa<IntegerType>(getType()))
3265 return emitOpError() << "expected integer type";
3266 } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
3267 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
3268 unsigned floatWidth = APFloat::getSizeInBits(sem);
3269 if (auto floatTy = dyn_cast<FloatType>(getType())) {
3270 if (floatTy.getWidth() != floatWidth) {
3271 return emitOpError() << "expected float type of width " << floatWidth;
3272 }
3273 }
3274 // See the comment for getLLVMConstant for more details about why 8-bit
3275 // floats can be represented by integers.
3276 if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
3277 return emitOpError() << "expected integer type of width " << floatWidth;
3278 }
3279 } else if (isa<ElementsAttr>(getValue())) {
3280 if (hasScalableVectorType(getType())) {
3281 // The exact number of elements of a scalable vector is unknown, so we
3282 // allow only splat attributes.
3283 auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
3284 if (!splatElementsAttr)
3285 return emitOpError()
3286 << "scalable vector type requires a splat attribute";
3287 return success();
3288 }
3289 if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
3290 return emitOpError() << "expected vector or array type";
3291 // The number of elements of the attribute and the type must match.
3292 if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3293 int64_t attrNumElements = elementsAttr.getNumElements();
3294 if (getNumElements(getType()) != attrNumElements)
3295 return emitOpError()
3296 << "type and attribute have a different number of elements: "
3297 << getNumElements(getType()) << " vs. " << attrNumElements;
3298 }
3299 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
3300 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
3301 if (!arrayType)
3302 return emitOpError() << "expected array type";
3303 // When the attribute is an ArrayAttr, check that its nesting matches the
3304 // corresponding ArrayType or VectorType nesting.
3305 return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
3306 } else {
3307 return emitOpError()
3308 << "only supports integer, float, string or elements attributes";
3309 }
3310
3311 return success();
3312}
3313
3314bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
3315 // The value's type must be the same as the provided type.
3316 auto typedAttr = dyn_cast<TypedAttr>(value);
3317 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
3318 return false;
3319 // The value's type must be an LLVM compatible type.
3320 if (!isCompatibleType(type))
3321 return false;
3322 // TODO: Add support for additional attributes kinds once needed.
3323 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
3324}
3325
3326ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
3327 Type type, Location loc) {
3328 if (isBuildableWith(value, type))
3329 return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
3330 return nullptr;
3331}
3332
// Constant op constant-folds to its value: the fold result is always the op's
// own `value` attribute, and the fold adaptor is not consulted.
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
3335
3336//===----------------------------------------------------------------------===//
3337// AtomicRMWOp
3338//===----------------------------------------------------------------------===//
3339
3340void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3341 AtomicBinOp binOp, Value ptr, Value val,
3342 AtomicOrdering ordering, StringRef syncscope,
3343 unsigned alignment, bool isVolatile) {
3344 build(builder, state, val.getType(), binOp, ptr, val, ordering,
3345 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3346 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
3347 /*access_groups=*/nullptr,
3348 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3349}
3350
3351LogicalResult AtomicRMWOp::verify() {
3352 auto valType = getVal().getType();
3353 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3354 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax ||
3355 getBinOp() == AtomicBinOp::fminimum ||
3356 getBinOp() == AtomicBinOp::fmaximum) {
3357 if (isCompatibleVectorType(valType)) {
3358 if (isScalableVectorType(valType))
3359 return emitOpError("expected LLVM IR fixed vector type");
3360 Type elemType = llvm::cast<VectorType>(valType).getElementType();
3361 if (!isCompatibleFloatingPointType(elemType))
3362 return emitOpError(
3363 "expected LLVM IR floating point type for vector element");
3364 } else if (!isCompatibleFloatingPointType(valType)) {
3365 return emitOpError("expected LLVM IR floating point type");
3366 }
3367 } else if (getBinOp() == AtomicBinOp::xchg) {
3368 DataLayout dataLayout = DataLayout::closest(*this);
3369 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3370 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
3371 } else {
3372 auto intType = llvm::dyn_cast<IntegerType>(valType);
3373 unsigned intBitWidth = intType ? intType.getWidth() : 0;
3374 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3375 intBitWidth != 64)
3376 return emitOpError("expected LLVM IR integer type");
3377 }
3378
3379 if (static_cast<unsigned>(getOrdering()) <
3380 static_cast<unsigned>(AtomicOrdering::monotonic))
3381 return emitOpError() << "expected at least '"
3382 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3383 << "' ordering";
3384
3385 return success();
3386}
3387
3388//===----------------------------------------------------------------------===//
3389// AtomicCmpXchgOp
3390//===----------------------------------------------------------------------===//
3391
3392/// Returns an LLVM struct type that contains a value type and a boolean type.
3393static LLVMStructType getValAndBoolStructType(Type valType) {
3394 auto boolType = IntegerType::get(valType.getContext(), 1);
3395 return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
3396}
3397
3398void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3399 Value ptr, Value cmp, Value val,
3400 AtomicOrdering successOrdering,
3401 AtomicOrdering failureOrdering, StringRef syncscope,
3402 unsigned alignment, bool isWeak, bool isVolatile) {
3403 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
3404 successOrdering, failureOrdering,
3405 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3406 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
3407 isVolatile, /*access_groups=*/nullptr,
3408 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3409}
3410
3411LogicalResult AtomicCmpXchgOp::verify() {
3412 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
3413 if (!ptrType)
3414 return emitOpError("expected LLVM IR pointer type for operand #0");
3415 auto valType = getVal().getType();
3416 DataLayout dataLayout = DataLayout::closest(*this);
3417 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3418 return emitOpError("unexpected LLVM IR type");
3419 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3420 getFailureOrdering() < AtomicOrdering::monotonic)
3421 return emitOpError("ordering must be at least 'monotonic'");
3422 if (getFailureOrdering() == AtomicOrdering::release ||
3423 getFailureOrdering() == AtomicOrdering::acq_rel)
3424 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
3425 return success();
3426}
3427
3428//===----------------------------------------------------------------------===//
3429// FenceOp
3430//===----------------------------------------------------------------------===//
3431
3432void FenceOp::build(OpBuilder &builder, OperationState &state,
3433 AtomicOrdering ordering, StringRef syncscope) {
3434 build(builder, state, ordering,
3435 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
3436}
3437
3438LogicalResult FenceOp::verify() {
3439 if (getOrdering() == AtomicOrdering::not_atomic ||
3440 getOrdering() == AtomicOrdering::unordered ||
3441 getOrdering() == AtomicOrdering::monotonic)
3442 return emitOpError("can be given only acquire, release, acq_rel, "
3443 "and seq_cst orderings");
3444 return success();
3445}
3446
3447//===----------------------------------------------------------------------===//
3448// Verifier for extension ops
3449//===----------------------------------------------------------------------===//
3450
3451/// Verifies that the given extension operation operates on consistent scalars
3452/// or vectors, and that the target width is larger than the input width.
3453template <class ExtOp>
3454static LogicalResult verifyExtOp(ExtOp op) {
3455 IntegerType inputType, outputType;
3456 if (isCompatibleVectorType(op.getArg().getType())) {
3457 if (!isCompatibleVectorType(op.getResult().getType()))
3458 return op.emitError(
3459 "input type is a vector but output type is an integer");
3460 if (getVectorNumElements(op.getArg().getType()) !=
3461 getVectorNumElements(op.getResult().getType()))
3462 return op.emitError("input and output vectors are of incompatible shape");
3463 // Because this is a CastOp, the element of vectors is guaranteed to be an
3464 // integer.
3465 inputType = cast<IntegerType>(
3466 cast<VectorType>(op.getArg().getType()).getElementType());
3467 outputType = cast<IntegerType>(
3468 cast<VectorType>(op.getResult().getType()).getElementType());
3469 } else {
3470 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3471 // an integer.
3472 inputType = cast<IntegerType>(op.getArg().getType());
3473 outputType = dyn_cast<IntegerType>(op.getResult().getType());
3474 if (!outputType)
3475 return op.emitError(
3476 "input type is an integer but output type is a vector");
3477 }
3478
3479 if (outputType.getWidth() <= inputType.getWidth())
3480 return op.emitError("integer width of the output type is smaller or "
3481 "equal to the integer width of the input type");
3482 return success();
3483}
3484
3485//===----------------------------------------------------------------------===//
3486// ZExtOp
3487//===----------------------------------------------------------------------===//
3488
3489LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
3490
3491OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3492 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3493 if (!arg)
3494 return {};
3495
3496 size_t targetSize = cast<IntegerType>(getType()).getWidth();
3497 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
3498}
3499
3500//===----------------------------------------------------------------------===//
3501// SExtOp
3502//===----------------------------------------------------------------------===//
3503
3504LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
3505
3506//===----------------------------------------------------------------------===//
3507// Folder and verifier for LLVM::BitcastOp
3508//===----------------------------------------------------------------------===//
3509
3510/// Folds a cast op that can be chained.
3511template <typename T>
3512static OpFoldResult foldChainableCast(T castOp,
3513 typename T::FoldAdaptor adaptor) {
3514 // cast(x : T0, T0) -> x
3515 if (castOp.getArg().getType() == castOp.getType())
3516 return castOp.getArg();
3517 if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3518 // cast(cast(x : T0, T1), T0) -> x
3519 if (prev.getArg().getType() == castOp.getType())
3520 return prev.getArg();
3521 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
3522 castOp.getArgMutable().set(prev.getArg());
3523 return Value{castOp};
3524 }
3525 return {};
3526}
3527
// Bitcasts fold through the generic chained-cast folder. Identity casts and
// round trips collapse to the original value; a cast of a cast is rewritten
// in place to cast the innermost operand.
OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
  return foldChainableCast(*this, adaptor);
}
3531
3532LogicalResult LLVM::BitcastOp::verify() {
3533 auto resultType = llvm::dyn_cast<LLVMPointerType>(
3534 extractVectorElementType(getResult().getType()));
3535 auto sourceType = llvm::dyn_cast<LLVMPointerType>(
3536 extractVectorElementType(getArg().getType()));
3537
3538 // If one of the types is a pointer (or vector of pointers), then
3539 // both source and result type have to be pointers.
3540 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3541 return emitOpError("can only cast pointers from and to pointers");
3542
3543 if (!resultType)
3544 return success();
3545
3546 auto isVector = llvm::IsaPred<VectorType>;
3547
3548 // Due to bitcast requiring both operands to be of the same size, it is not
3549 // possible for only one of the two to be a pointer of vectors.
3550 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3551 return emitOpError("cannot cast pointer to vector of pointers");
3552
3553 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3554 return emitOpError("cannot cast vector of pointers to pointer");
3555
3556 // Bitcast cannot cast between pointers of different address spaces.
3557 // 'llvm.addrspacecast' must be used for this purpose instead.
3558 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3559 return emitOpError("cannot cast pointers of different address spaces, "
3560 "use 'llvm.addrspacecast' instead");
3561
3562 return success();
3563}
3564
3565//===----------------------------------------------------------------------===//
3566// Folder for LLVM::AddrSpaceCastOp
3567//===----------------------------------------------------------------------===//
3568
// Address space casts fold through the generic chained-cast folder. Identity
// casts and round trips collapse to the original value; a cast of a cast is
// rewritten in place to cast the innermost operand.
OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
  return foldChainableCast(*this, adaptor);
}
3572
// Reports the cast operand as the source of this view-like op.
Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3574
3575//===----------------------------------------------------------------------===//
3576// Folder for LLVM::GEPOp
3577//===----------------------------------------------------------------------===//
3578
/// Folds `llvm.getelementptr` in two ways:
///  * a GEP whose result type equals its base type and whose single index is
///    the constant zero folds to the base pointer;
///  * dynamic indices that are known constants fitting in
///    `kGEPConstantBitWidth` signed bits are moved into the raw constant
///    indices. The op is updated in place and returned as the fold result.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  // Uniform view over all indices. Constant entries come from the raw
  // constant-index attribute; dynamic entries come from the folded operand
  // attributes, which may be null for non-constant operands (hence the
  // `dyn_cast_or_null` uses below).
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {

      // Keep this index exactly as the op currently stores it: either its
      // dynamic SSA value or its existing constant.
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());

      continue;
    }

    // A dynamic index with a known constant value that fits: promote it.
    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }
  if (changed) {
    // Re-split the argument list into raw constant indices and dynamic index
    // operands, then update the op in place.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }

  return {};
}
3624
// Reports the GEP base pointer as the source of this view-like op.
Value LLVM::GEPOp::getViewSource() { return getBase(); }
3626
3627//===----------------------------------------------------------------------===//
3628// ShlOp
3629//===----------------------------------------------------------------------===//
3630
3631OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3632 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3633 if (!rhs)
3634 return {};
3635
3636 if (rhs.getValue().getZExtValue() >=
3637 getLhs().getType().getIntOrFloatBitWidth())
3638 return {}; // TODO: Fold into poison.
3639
3640 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3641 if (!lhs)
3642 return {};
3643
3644 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
3645}
3646
3647//===----------------------------------------------------------------------===//
3648// OrOp
3649//===----------------------------------------------------------------------===//
3650
3651OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
3652 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3653 if (!lhs)
3654 return {};
3655
3656 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3657 if (!rhs)
3658 return {};
3659
3660 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
3661}
3662
3663//===----------------------------------------------------------------------===//
3664// CallIntrinsicOp
3665//===----------------------------------------------------------------------===//
3666
3667LogicalResult CallIntrinsicOp::verify() {
3668 if (!getIntrin().starts_with("llvm."))
3669 return emitOpError() << "intrinsic name must start with 'llvm.'";
3670 if (failed(verifyOperandBundles(*this)))
3671 return failure();
3672 return success();
3673}
3674
3675void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3676 mlir::StringAttr intrin, mlir::ValueRange args) {
3677 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
3678 FastmathFlagsAttr{},
3679 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3680 /*res_attrs=*/{});
3681}
3682
3683void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3684 mlir::StringAttr intrin, mlir::ValueRange args,
3685 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3686 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
3687 fastMathFlags,
3688 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3689 /*res_attrs=*/{});
3690}
3691
3692void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3693 mlir::Type resultType, mlir::StringAttr intrin,
3694 mlir::ValueRange args) {
3695 build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
3696 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3697 /*res_attrs=*/{});
3698}
3699
3700void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3701 mlir::TypeRange resultTypes,
3702 mlir::StringAttr intrin, mlir::ValueRange args,
3703 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3704 build(builder, state, resultTypes, intrin, args, fastMathFlags,
3705 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3706 /*res_attrs=*/{});
3707}
3708
3709ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
3710 OperationState &result) {
3711 StringAttr intrinAttr;
3712 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3713 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
3714 SmallVector<SmallVector<Type>> opBundleOperandTypes;
3715 ArrayAttr opBundleTags;
3716
3717 // Parse intrinsic name.
3718 if (parser.parseCustomAttributeWithFallback(
3719 intrinAttr, parser.getBuilder().getType<NoneType>()))
3720 return failure();
3721 result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
3722 intrinAttr);
3723
3724 if (parser.parseLParen())
3725 return failure();
3726
3727 // Parse the function arguments.
3728 if (parser.parseOperandList(operands))
3729 return mlir::failure();
3730
3731 if (parser.parseRParen())
3732 return mlir::failure();
3733
3734 // Handle bundles.
3735 SMLoc opBundlesLoc = parser.getCurrentLocation();
3736 if (std::optional<ParseResult> result = parseOpBundles(
3737 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
3738 result && failed(*result))
3739 return failure();
3740 if (opBundleTags && !opBundleTags.empty())
3741 result.addAttribute(
3742 CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
3743 opBundleTags);
3744
3745 if (parser.parseOptionalAttrDict(result.attributes))
3746 return mlir::failure();
3747
3748 SmallVector<DictionaryAttr> argAttrs;
3749 SmallVector<DictionaryAttr> resultAttrs;
3750 if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
3751 operands, argAttrs, resultAttrs))
3752 return failure();
3753 call_interface_impl::addArgAndResultAttrs(
3754 parser.getBuilder(), result, argAttrs, resultAttrs,
3755 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
3756
3757 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
3758 opBundleOperandTypes,
3759 getOpBundleSizesAttrName(result.name)))
3760 return failure();
3761
3762 int32_t numOpBundleOperands = 0;
3763 for (const auto &operands : opBundleOperands)
3764 numOpBundleOperands += operands.size();
3765
3766 result.addAttribute(
3767 CallIntrinsicOp::getOperandSegmentSizeAttr(),
3768 parser.getBuilder().getDenseI32ArrayAttr(
3769 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
3770
3771 return mlir::success();
3772}
3773
3774void CallIntrinsicOp::print(OpAsmPrinter &p) {
3775 p << ' ';
3776 p.printAttributeWithoutType(getIntrinAttr());
3777
3778 OperandRange args = getArgs();
3779 p << "(" << args << ")";
3780
3781 // Operand bundles.
3782 if (!getOpBundleOperands().empty()) {
3783 p << ' ';
3784 printOpBundles(p, *this, getOpBundleOperands(),
3785 getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
3786 }
3787
3788 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
3789 {getOperandSegmentSizesAttrName(),
3790 getOpBundleSizesAttrName(), getIntrinAttrName(),
3791 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
3792 getResAttrsAttrName()});
3793
3794 p << " : ";
3795
3796 // Reconstruct the MLIR function type from operand and result types.
3797 call_interface_impl::printFunctionSignature(
3798 p, args.getTypes(), getArgAttrsAttr(),
3799 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
3800}
3801
3802//===----------------------------------------------------------------------===//
3803// LinkerOptionsOp
3804//===----------------------------------------------------------------------===//
3805
3806LogicalResult LinkerOptionsOp::verify() {
3807 if (mlir::Operation *parentOp = (*this)->getParentOp();
3808 parentOp && !satisfiesLLVMModule(parentOp))
3809 return emitOpError("must appear at the module level");
3810 return success();
3811}
3812
3813//===----------------------------------------------------------------------===//
3814// ModuleFlagsOp
3815//===----------------------------------------------------------------------===//
3816
3817LogicalResult ModuleFlagsOp::verify() {
3818 if (Operation *parentOp = (*this)->getParentOp();
3819 parentOp && !satisfiesLLVMModule(parentOp))
3820 return emitOpError("must appear at the module level");
3821 for (Attribute flag : getFlags())
3822 if (!isa<ModuleFlagAttr>(flag))
3823 return emitOpError("expected a module flag attribute");
3824 return success();
3825}
3826
3827//===----------------------------------------------------------------------===//
3828// InlineAsmOp
3829//===----------------------------------------------------------------------===//
3830
3831void InlineAsmOp::getEffects(
3832 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3833 &effects) {
3834 if (getHasSideEffects()) {
3835 effects.emplace_back(MemoryEffects::Write::get());
3836 effects.emplace_back(MemoryEffects::Read::get());
3837 }
3838}
3839
3840//===----------------------------------------------------------------------===//
3841// BlockAddressOp
3842//===----------------------------------------------------------------------===//
3843
3844LogicalResult
3845BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3846 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
3847 getBlockAddr().getFunction());
3848 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
3849
3850 if (!function)
3851 return emitOpError("must reference a function defined by 'llvm.func'");
3852
3853 return success();
3854}
3855
3856LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
3857 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
3858 parentLLVMModule(*this), getBlockAddr().getFunction()));
3859}
3860
3861BlockTagOp BlockAddressOp::getBlockTagOp() {
3862 auto funcOp = dyn_cast<LLVMFuncOp>(mlir::SymbolTable::lookupNearestSymbolFrom(
3863 parentLLVMModule(*this), getBlockAddr().getFunction()));
3864 if (!funcOp)
3865 return nullptr;
3866
3867 BlockTagOp blockTagOp = nullptr;
3868 funcOp.walk([&](LLVM::BlockTagOp labelOp) {
3869 if (labelOp.getTag() == getBlockAddr().getTag()) {
3870 blockTagOp = labelOp;
3871 return WalkResult::interrupt();
3872 }
3873 return WalkResult::advance();
3874 });
3875 return blockTagOp;
3876}
3877
3878LogicalResult BlockAddressOp::verify() {
3879 if (!getBlockTagOp())
3880 return emitOpError(
3881 "expects an existing block label target in the referenced function");
3882
3883 return success();
3884}
3885
/// Fold a blockaddress operation to a dedicated blockaddress
/// attribute. The fold result is always the op's block address attribute
/// (`getBlockAddr()`); the fold adaptor is not consulted.
OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
3889
3890//===----------------------------------------------------------------------===//
3891// LLVM::IndirectBrOp
3892//===----------------------------------------------------------------------===//
3893
3894SuccessorOperands IndirectBrOp::getSuccessorOperands(unsigned index) {
3895 assert(index < getNumSuccessors() && "invalid successor index");
3896 return SuccessorOperands(getSuccOperandsMutable()[index]);
3897}
3898
3899void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3900 Value addr, ArrayRef<ValueRange> succOperands,
3901 BlockRange successors) {
3902 odsState.addOperands(addr);
3903 for (ValueRange range : succOperands)
3904 odsState.addOperands(range);
3905 SmallVector<int32_t> rangeSegments;
3906 for (ValueRange range : succOperands)
3907 rangeSegments.push_back(range.size());
3908 odsState.getOrAddProperties<Properties>().indbr_operand_segments =
3909 odsBuilder.getDenseI32ArrayAttr(rangeSegments);
3910 odsState.addSuccessors(successors);
3911}
3912
3913static ParseResult parseIndirectBrOpSucessors(
3914 OpAsmParser &parser, Type &flagType,
3915 SmallVectorImpl<Block *> &succOperandBlocks,
3916 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
3917 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
3918 if (failed(Result: parser.parseCommaSeparatedList(
3919 delimiter: OpAsmParser::Delimiter::Square,
3920 parseElementFn: [&]() {
3921 Block *destination = nullptr;
3922 SmallVector<OpAsmParser::UnresolvedOperand> operands;
3923 SmallVector<Type> operandTypes;
3924
3925 if (parser.parseSuccessor(dest&: destination).failed())
3926 return failure();
3927
3928 if (succeeded(Result: parser.parseOptionalLParen())) {
3929 if (failed(Result: parser.parseOperandList(
3930 result&: operands, delimiter: OpAsmParser::Delimiter::None)) ||
3931 failed(Result: parser.parseColonTypeList(result&: operandTypes)) ||
3932 failed(Result: parser.parseRParen()))
3933 return failure();
3934 }
3935 succOperandBlocks.push_back(Elt: destination);
3936 succOperands.emplace_back(Args&: operands);
3937 succOperandsTypes.emplace_back(Args&: operandTypes);
3938 return success();
3939 },
3940 contextMessage: "successor blocks")))
3941 return failure();
3942 return success();
3943}
3944
3945static void
3946printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
3947 SuccessorRange succs, OperandRangeRange succOperands,
3948 const TypeRangeRange &succOperandsTypes) {
3949 p << "[";
3950 llvm::interleave(
3951 c: llvm::zip(t&: succs, u&: succOperands),
3952 each_fn: [&](auto i) {
3953 p.printNewline();
3954 p.printSuccessorAndUseList(successor: std::get<0>(i), succOperands: std::get<1>(i));
3955 },
3956 between_fn: [&] { p << ','; });
3957 if (!succOperands.empty())
3958 p.printNewline();
3959 p << "]";
3960}
3961
3962//===----------------------------------------------------------------------===//
3963// AssumeOp (intrinsic)
3964//===----------------------------------------------------------------------===//
3965
3966void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3967 mlir::Value cond) {
3968 return build(builder, state, cond, /*op_bundle_operands=*/{},
3969 /*op_bundle_tags=*/ArrayAttr{});
3970}
3971
3972void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3973 Value cond,
3974 ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) {
3975 SmallVector<ValueRange> opBundleOperands;
3976 SmallVector<Attribute> opBundleTags;
3977 opBundleOperands.reserve(opBundles.size());
3978 opBundleTags.reserve(opBundles.size());
3979
3980 for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) {
3981 opBundleOperands.emplace_back(bundle.inputs());
3982 opBundleTags.push_back(
3983 StringAttr::get(builder.getContext(), bundle.getTag()));
3984 }
3985
3986 auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
3987 return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
3988}
3989
3990void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3991 Value cond, llvm::StringRef tag, ValueRange args) {
3992 llvm::OperandBundleDefT<Value> opBundle(
3993 tag.str(), SmallVector<Value>(args.begin(), args.end()));
3994 return build(builder, state, cond, opBundle);
3995}
3996
3997void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3998 Value cond, AssumeAlignTag, Value ptr, Value align) {
3999 return build(builder, state, cond, "align", ValueRange{ptr, align});
4000}
4001
4002void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4003 Value cond, AssumeSeparateStorageTag, Value ptr1,
4004 Value ptr2) {
4005 return build(builder, state, cond, "separate_storage",
4006 ValueRange{ptr1, ptr2});
4007}
4008
// The only op-specific check is the shared operand-bundle verification.
LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
4010
4011//===----------------------------------------------------------------------===//
4012// masked_gather (intrinsic)
4013//===----------------------------------------------------------------------===//
4014
4015LogicalResult LLVM::masked_gather::verify() {
4016 auto ptrsVectorType = getPtrs().getType();
4017 Type expectedPtrsVectorType =
4018 LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
4019 LLVM::getVectorNumElements(getRes().getType()));
4020 // Vector of pointers type should match result vector type, other than the
4021 // element type.
4022 if (ptrsVectorType != expectedPtrsVectorType)
4023 return emitOpError("expected operand #1 type to be ")
4024 << expectedPtrsVectorType;
4025 return success();
4026}
4027
4028//===----------------------------------------------------------------------===//
4029// masked_scatter (intrinsic)
4030//===----------------------------------------------------------------------===//
4031
4032LogicalResult LLVM::masked_scatter::verify() {
4033 auto ptrsVectorType = getPtrs().getType();
4034 Type expectedPtrsVectorType =
4035 LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
4036 LLVM::getVectorNumElements(getValue().getType()));
4037 // Vector of pointers type should match value vector type, other than the
4038 // element type.
4039 if (ptrsVectorType != expectedPtrsVectorType)
4040 return emitOpError("expected operand #2 type to be ")
4041 << expectedPtrsVectorType;
4042 return success();
4043}
4044
4045//===----------------------------------------------------------------------===//
4046// InlineAsmOp
4047//===----------------------------------------------------------------------===//
4048
4049LogicalResult InlineAsmOp::verify() {
4050 if (!getTailCallKindAttr())
4051 return success();
4052
4053 if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4054 return emitOpError(
4055 "tail call kind 'musttail' is not supported by this operation");
4056
4057 return success();
4058}
4059
4060//===----------------------------------------------------------------------===//
4061// LLVMDialect initialization, type parsing, and registration.
4062//===----------------------------------------------------------------------===//
4063
/// Dialect initialization. It:
///  * registers the dialect's attributes and types;
///  * registers every operation listed in LLVMOps.cpp.inc and
///    LLVMIntrinsicOps.cpp.inc;
///  * permits unregistered operations;
///  * declares a promised DialectInlinerInterface, whose implementation is
///    expected to be attached elsewhere.
void LLVMDialect::initialize() {
  registerAttributes();

  // Types added directly here, in addition to those added by registerTypes().
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType>();
  // clang-format on
  registerTypes();

  // Both generated op lists are spliced into one template argument list; the
  // lone comma separates them.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      ,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
  declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
}
4087
4088#define GET_OP_CLASSES
4089#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4090
4091#define GET_OP_CLASSES
4092#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4093
4094LogicalResult LLVMDialect::verifyDataLayoutString(
4095 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
4096 llvm::Expected<llvm::DataLayout> maybeDataLayout =
4097 llvm::DataLayout::parse(descr);
4098 if (maybeDataLayout)
4099 return success();
4100
4101 std::string message;
4102 llvm::raw_string_ostream messageStream(message);
4103 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
4104 reportError("invalid data layout descriptor: " + message);
4105 return failure();
4106}
4107
4108/// Verify LLVM dialect attributes.
4109LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
4110 NamedAttribute attr) {
4111 // If the data layout attribute is present, it must use the LLVM data layout
4112 // syntax. Try parsing it and report errors in case of failure. Users of this
4113 // attribute may assume it is well-formed and can pass it to the (asserting)
4114 // llvm::DataLayout constructor.
4115 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
4116 return success();
4117 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
4118 return verifyDataLayoutString(
4119 stringAttr.getValue(),
4120 [op](const Twine &message) { op->emitOpError() << message.str(); });
4121
4122 return op->emitOpError() << "expected '"
4123 << LLVM::LLVMDialect::getDataLayoutAttrName()
4124 << "' to be a string attributes";
4125}
4126
4127LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
4128 Type paramType,
4129 NamedAttribute paramAttr) {
4130 // LLVM attribute may be attached to a result of operation that has not been
4131 // converted to LLVM dialect yet, so the result may have a type with unknown
4132 // representation in LLVM dialect type space. In this case we cannot verify
4133 // whether the attribute may be
4134 bool verifyValueType = isCompatibleType(paramType);
4135 StringAttr name = paramAttr.getName();
4136
4137 auto checkUnitAttrType = [&]() -> LogicalResult {
4138 if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
4139 return op->emitError() << name << " should be a unit attribute";
4140 return success();
4141 };
4142 auto checkTypeAttrType = [&]() -> LogicalResult {
4143 if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
4144 return op->emitError() << name << " should be a type attribute";
4145 return success();
4146 };
4147 auto checkIntegerAttrType = [&]() -> LogicalResult {
4148 if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
4149 return op->emitError() << name << " should be an integer attribute";
4150 return success();
4151 };
4152 auto checkPointerType = [&]() -> LogicalResult {
4153 if (!llvm::isa<LLVMPointerType>(paramType))
4154 return op->emitError()
4155 << name << " attribute attached to non-pointer LLVM type";
4156 return success();
4157 };
4158 auto checkIntegerType = [&]() -> LogicalResult {
4159 if (!llvm::isa<IntegerType>(paramType))
4160 return op->emitError()
4161 << name << " attribute attached to non-integer LLVM type";
4162 return success();
4163 };
4164 auto checkPointerTypeMatches = [&]() -> LogicalResult {
4165 if (failed(checkPointerType()))
4166 return failure();
4167
4168 return success();
4169 };
4170
4171 // Check a unit attribute that is attached to a pointer value.
4172 if (name == LLVMDialect::getNoAliasAttrName() ||
4173 name == LLVMDialect::getReadonlyAttrName() ||
4174 name == LLVMDialect::getReadnoneAttrName() ||
4175 name == LLVMDialect::getWriteOnlyAttrName() ||
4176 name == LLVMDialect::getNestAttrName() ||
4177 name == LLVMDialect::getNoCaptureAttrName() ||
4178 name == LLVMDialect::getNoFreeAttrName() ||
4179 name == LLVMDialect::getNonNullAttrName()) {
4180 if (failed(checkUnitAttrType()))
4181 return failure();
4182 if (verifyValueType && failed(checkPointerType()))
4183 return failure();
4184 return success();
4185 }
4186
4187 // Check a type attribute that is attached to a pointer value.
4188 if (name == LLVMDialect::getStructRetAttrName() ||
4189 name == LLVMDialect::getByValAttrName() ||
4190 name == LLVMDialect::getByRefAttrName() ||
4191 name == LLVMDialect::getElementTypeAttrName() ||
4192 name == LLVMDialect::getInAllocaAttrName() ||
4193 name == LLVMDialect::getPreallocatedAttrName()) {
4194 if (failed(checkTypeAttrType()))
4195 return failure();
4196 if (verifyValueType && failed(checkPointerTypeMatches()))
4197 return failure();
4198 return success();
4199 }
4200
4201 // Check a unit attribute that is attached to an integer value.
4202 if (name == LLVMDialect::getSExtAttrName() ||
4203 name == LLVMDialect::getZExtAttrName()) {
4204 if (failed(checkUnitAttrType()))
4205 return failure();
4206 if (verifyValueType && failed(checkIntegerType()))
4207 return failure();
4208 return success();
4209 }
4210
4211 // Check an integer attribute that is attached to a pointer value.
4212 if (name == LLVMDialect::getAlignAttrName() ||
4213 name == LLVMDialect::getDereferenceableAttrName() ||
4214 name == LLVMDialect::getDereferenceableOrNullAttrName()) {
4215 if (failed(checkIntegerAttrType()))
4216 return failure();
4217 if (verifyValueType && failed(checkPointerType()))
4218 return failure();
4219 return success();
4220 }
4221
4222 // Check an integer attribute that is attached to a pointer value.
4223 if (name == LLVMDialect::getStackAlignmentAttrName()) {
4224 if (failed(checkIntegerAttrType()))
4225 return failure();
4226 return success();
4227 }
4228
4229 // Check a unit attribute that can be attached to arbitrary types.
4230 if (name == LLVMDialect::getNoUndefAttrName() ||
4231 name == LLVMDialect::getInRegAttrName() ||
4232 name == LLVMDialect::getReturnedAttrName())
4233 return checkUnitAttrType();
4234
4235 return success();
4236}
4237
4238/// Verify LLVMIR function argument attributes.
4239LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
4240 unsigned regionIdx,
4241 unsigned argIdx,
4242 NamedAttribute argAttr) {
4243 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4244 if (!funcOp)
4245 return success();
4246 Type argType = funcOp.getArgumentTypes()[argIdx];
4247
4248 return verifyParameterAttribute(op, argType, argAttr);
4249}
4250
4251LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
4252 unsigned regionIdx,
4253 unsigned resIdx,
4254 NamedAttribute resAttr) {
4255 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4256 if (!funcOp)
4257 return success();
4258 Type resType = funcOp.getResultTypes()[resIdx];
4259
4260 // Check to see if this function has a void return with a result attribute
4261 // to it. It isn't clear what semantics we would assign to that.
4262 if (llvm::isa<LLVMVoidType>(resType))
4263 return op->emitError() << "cannot attach result attributes to functions "
4264 "with a void return";
4265
4266 // Check to see if this attribute is allowed as a result attribute. Only
4267 // explicitly forbidden LLVM attributes will cause an error.
4268 auto name = resAttr.getName();
4269 if (name == LLVMDialect::getAllocAlignAttrName() ||
4270 name == LLVMDialect::getAllocatedPointerAttrName() ||
4271 name == LLVMDialect::getByValAttrName() ||
4272 name == LLVMDialect::getByRefAttrName() ||
4273 name == LLVMDialect::getInAllocaAttrName() ||
4274 name == LLVMDialect::getNestAttrName() ||
4275 name == LLVMDialect::getNoCaptureAttrName() ||
4276 name == LLVMDialect::getNoFreeAttrName() ||
4277 name == LLVMDialect::getPreallocatedAttrName() ||
4278 name == LLVMDialect::getReadnoneAttrName() ||
4279 name == LLVMDialect::getReadonlyAttrName() ||
4280 name == LLVMDialect::getReturnedAttrName() ||
4281 name == LLVMDialect::getStackAlignmentAttrName() ||
4282 name == LLVMDialect::getStructRetAttrName() ||
4283 name == LLVMDialect::getWriteOnlyAttrName())
4284 return op->emitError() << name << " is not a valid result attribute";
4285 return verifyParameterAttribute(op, resType, resAttr);
4286}
4287
4288Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
4289 Type type, Location loc) {
4290 // If this was folded from an operation other than llvm.mlir.constant, it
4291 // should be materialized as such. Note that an llvm.mlir.zero may fold into
4292 // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
4293 if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
4294 if (isa<LLVM::LLVMPointerType>(type))
4295 return builder.create<LLVM::AddressOfOp>(loc, type, symbol);
4296 if (isa<LLVM::UndefAttr>(value))
4297 return builder.create<LLVM::UndefOp>(loc, type);
4298 if (isa<LLVM::PoisonAttr>(value))
4299 return builder.create<LLVM::PoisonOp>(loc, type);
4300 if (isa<LLVM::ZeroAttr>(value))
4301 return builder.create<LLVM::ZeroOp>(loc, type);
4302 // Otherwise try materializing it as a regular llvm.mlir.constant op.
4303 return LLVM::ConstantOp::materialize(builder, value, type, loc);
4304}
4305
4306//===----------------------------------------------------------------------===//
4307// Utility functions.
4308//===----------------------------------------------------------------------===//
4309
4310Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
4311 StringRef name, StringRef value,
4312 LLVM::Linkage linkage) {
4313 assert(builder.getInsertionBlock() &&
4314 builder.getInsertionBlock()->getParentOp() &&
4315 "expected builder to point to a block constrained in an op");
4316 auto module =
4317 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
4318 assert(module && "builder points to an op outside of a module");
4319
4320 // Create the global at the entry of the module.
4321 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
4322 MLIRContext *ctx = builder.getContext();
4323 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
4324 auto global = moduleBuilder.create<LLVM::GlobalOp>(
4325 loc, type, /*isConstant=*/true, linkage, name,
4326 builder.getStringAttr(value), /*alignment=*/0);
4327
4328 LLVMPointerType ptrType = LLVMPointerType::get(ctx);
4329 // Get the pointer to the first character in the global string.
4330 Value globalPtr =
4331 builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr());
4332 return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr,
4333 ArrayRef<GEPArg>{0, 0});
4334}
4335
4336bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
4337 return op->hasTrait<OpTrait::SymbolTable>() &&
4338 op->hasTrait<OpTrait::IsIsolatedFromAbove>();
4339}
4340

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp