1//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the LLVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "LLVMInlining.h"
16#include "TypeDetail.h"
17#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
18#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
19#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27
28#include "llvm/ADT/SCCIterator.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/AsmParser/Parser.h"
31#include "llvm/Bitcode/BitcodeReader.h"
32#include "llvm/Bitcode/BitcodeWriter.h"
33#include "llvm/IR/Attributes.h"
34#include "llvm/IR/Function.h"
35#include "llvm/IR/Type.h"
36#include "llvm/Support/Error.h"
37#include "llvm/Support/Mutex.h"
38#include "llvm/Support/SourceMgr.h"
39
40#include <numeric>
41#include <optional>
42
43using namespace mlir;
44using namespace mlir::LLVM;
45using mlir::LLVM::cconv::getMaxEnumValForCConv;
46using mlir::LLVM::linkage::getMaxEnumValForLinkage;
47
48#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
49
50//===----------------------------------------------------------------------===//
51// Property Helpers
52//===----------------------------------------------------------------------===//
53
54//===----------------------------------------------------------------------===//
55// IntegerOverflowFlags
56
57namespace mlir {
58static Attribute convertToAttribute(MLIRContext *ctx,
59 IntegerOverflowFlags flags) {
60 return IntegerOverflowFlagsAttr::get(ctx, flags);
61}
62
63static LogicalResult
64convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
65 function_ref<InFlightDiagnostic()> emitError) {
66 auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
67 if (!flagsAttr) {
68 return emitError() << "expected 'overflowFlags' attribute to be an "
69 "IntegerOverflowFlagsAttr, but got "
70 << attr;
71 }
72 flags = flagsAttr.getValue();
73 return success();
74}
75} // namespace mlir
76
77static ParseResult parseOverflowFlags(AsmParser &p,
78 IntegerOverflowFlags &flags) {
79 if (failed(result: p.parseOptionalKeyword(keyword: "overflow"))) {
80 flags = IntegerOverflowFlags::none;
81 return success();
82 }
83 if (p.parseLess())
84 return failure();
85 do {
86 StringRef kw;
87 SMLoc loc = p.getCurrentLocation();
88 if (p.parseKeyword(keyword: &kw))
89 return failure();
90 std::optional<IntegerOverflowFlags> flag =
91 symbolizeIntegerOverflowFlags(kw);
92 if (!flag)
93 return p.emitError(loc,
94 message: "invalid overflow flag: expected nsw, nuw, or none");
95 flags = flags | *flag;
96 } while (succeeded(result: p.parseOptionalComma()));
97 return p.parseGreater();
98}
99
100static void printOverflowFlags(AsmPrinter &p, Operation *op,
101 IntegerOverflowFlags flags) {
102 if (flags == IntegerOverflowFlags::none)
103 return;
104 p << " overflow<";
105 SmallVector<StringRef, 2> strs;
106 if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
107 strs.push_back(Elt: "nsw");
108 if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
109 strs.push_back(Elt: "nuw");
110 llvm::interleaveComma(c: strs, os&: p);
111 p << ">";
112}
113
114//===----------------------------------------------------------------------===//
115// Attribute Helpers
116//===----------------------------------------------------------------------===//
117
/// Name of the attribute that records the element type. Used by the
/// `llvm.alloca` parser/printer and the `llvm.getelementptr` builder in this
/// file.
static constexpr const char kElemTypeAttrName[] = "elem_type";
119
120static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
121 SmallVector<NamedAttribute, 8> filteredAttrs(
122 llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) {
123 if (attr.getName() == "fastmathFlags") {
124 auto defAttr =
125 FastmathFlagsAttr::get(attr.getValue().getContext(), {});
126 return defAttr != attr.getValue();
127 }
128 return true;
129 }));
130 return filteredAttrs;
131}
132
133static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
134 NamedAttrList &result) {
135 return parser.parseOptionalAttrDict(result);
136}
137
138static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
139 DictionaryAttr attrs) {
140 auto filteredAttrs = processFMFAttr(attrs.getValue());
141 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
142 printer.printOptionalAttrDict(
143 attrs: filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
144 } else {
145 printer.printOptionalAttrDict(attrs: filteredAttrs);
146 }
147}
148
149/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
150/// fully defined llvm.func.
151static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
152 Operation *op,
153 SymbolTableCollection &symbolTable) {
154 StringRef name = symbol.getValue();
155 auto func =
156 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
157 if (!func)
158 return op->emitOpError(message: "'")
159 << name << "' does not reference a valid LLVM function";
160 if (func.isExternal())
161 return op->emitOpError(message: "'") << name << "' does not have a definition";
162 return success();
163}
164
165/// Returns a boolean type that has the same shape as `type`. It supports both
166/// fixed size vectors as well as scalable vectors.
167static Type getI1SameShape(Type type) {
168 Type i1Type = IntegerType::get(type.getContext(), 1);
169 if (LLVM::isCompatibleVectorType(type))
170 return LLVM::getVectorType(elementType: i1Type, numElements: LLVM::getVectorNumElements(type));
171 return i1Type;
172}
173
174// Parses one of the keywords provided in the list `keywords` and returns the
175// position of the parsed keyword in the list. If none of the keywords from the
176// list is parsed, returns -1.
177static int parseOptionalKeywordAlternative(OpAsmParser &parser,
178 ArrayRef<StringRef> keywords) {
179 for (const auto &en : llvm::enumerate(First&: keywords)) {
180 if (succeeded(result: parser.parseOptionalKeyword(keyword: en.value())))
181 return en.index();
182 }
183 return -1;
184}
185
namespace {
/// Gives generic code (e.g. `parseOptionalLLVMKeyword`) uniform access to the
/// `stringify<Ty>` and `getMaxEnumValFor<Ty>` helpers of an LLVM dialect enum.
/// The primary template is empty; only the specializations registered below
/// are usable.
template <typename Ty>
struct EnumTraits {};

/// Specializes `EnumTraits` for `Ty` by forwarding to `stringify##Ty` and
/// `getMaxEnumValFor##Ty`, which must be visible where the macro is expanded.
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <>                                                                  \
  struct EnumTraits<Ty> {                                                      \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

// Enums that can be used with `parseOptionalLLVMKeyword`.
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
REGISTER_ENUM_TYPE(Visibility);
} // namespace
202
203/// Parse an enum from the keyword, or default to the provided default value.
204/// The return type is the enum type by default, unless overridden with the
205/// second template argument.
206template <typename EnumTy, typename RetTy = EnumTy>
207static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
208 OperationState &result,
209 EnumTy defaultValue) {
210 SmallVector<StringRef, 10> names;
211 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
212 names.push_back(Elt: EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
213
214 int index = parseOptionalKeywordAlternative(parser, keywords: names);
215 if (index == -1)
216 return static_cast<RetTy>(defaultValue);
217 return static_cast<RetTy>(index);
218}
219
220//===----------------------------------------------------------------------===//
// Printing, parsing and folding for LLVM::ICmpOp and LLVM::FCmpOp.
222//===----------------------------------------------------------------------===//
223
224void ICmpOp::print(OpAsmPrinter &p) {
225 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
226 << ", " << getOperand(1);
227 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
228 p << " : " << getLhs().getType();
229}
230
231void FCmpOp::print(OpAsmPrinter &p) {
232 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
233 << ", " << getOperand(1);
234 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
235 p << " : " << getLhs().getType();
236}
237
238// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
239// attribute-dict? `:` type
240// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
241// attribute-dict? `:` type
242template <typename CmpPredicateType>
243static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
244 StringAttr predicateAttr;
245 OpAsmParser::UnresolvedOperand lhs, rhs;
246 Type type;
247 SMLoc predicateLoc, trailingTypeLoc;
248 if (parser.getCurrentLocation(loc: &predicateLoc) ||
249 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
250 parser.parseOperand(result&: lhs) || parser.parseComma() ||
251 parser.parseOperand(result&: rhs) ||
252 parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() ||
253 parser.getCurrentLocation(loc: &trailingTypeLoc) || parser.parseType(result&: type) ||
254 parser.resolveOperand(operand: lhs, type, result&: result.operands) ||
255 parser.resolveOperand(operand: rhs, type, result&: result.operands))
256 return failure();
257
258 // Replace the string attribute `predicate` with an integer attribute.
259 int64_t predicateValue = 0;
260 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
261 std::optional<ICmpPredicate> predicate =
262 symbolizeICmpPredicate(predicateAttr.getValue());
263 if (!predicate)
264 return parser.emitError(loc: predicateLoc)
265 << "'" << predicateAttr.getValue()
266 << "' is an incorrect value of the 'predicate' attribute";
267 predicateValue = static_cast<int64_t>(*predicate);
268 } else {
269 std::optional<FCmpPredicate> predicate =
270 symbolizeFCmpPredicate(predicateAttr.getValue());
271 if (!predicate)
272 return parser.emitError(loc: predicateLoc)
273 << "'" << predicateAttr.getValue()
274 << "' is an incorrect value of the 'predicate' attribute";
275 predicateValue = static_cast<int64_t>(*predicate);
276 }
277
278 result.attributes.set("predicate",
279 parser.getBuilder().getI64IntegerAttr(predicateValue));
280
281 // The result type is either i1 or a vector type <? x i1> if the inputs are
282 // vectors.
283 if (!isCompatibleType(type))
284 return parser.emitError(loc: trailingTypeLoc,
285 message: "expected LLVM dialect-compatible type");
286 result.addTypes(newTypes: getI1SameShape(type));
287 return success();
288}
289
290ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
291 return parseCmpOp<ICmpPredicate>(parser, result);
292}
293
294ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
295 return parseCmpOp<FCmpPredicate>(parser, result);
296}
297
298/// Returns a scalar or vector boolean attribute of the given type.
299static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
300 auto boolAttr = BoolAttr::get(context: ctx, value);
301 ShapedType shapedType = dyn_cast<ShapedType>(type);
302 if (!shapedType)
303 return boolAttr;
304 return DenseElementsAttr::get(shapedType, boolAttr);
305}
306
307OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
308 if (getPredicate() != ICmpPredicate::eq &&
309 getPredicate() != ICmpPredicate::ne)
310 return {};
311
312 // cmpi(eq/ne, x, x) -> true/false
313 if (getLhs() == getRhs())
314 return getBoolAttribute(getType(), getContext(),
315 getPredicate() == ICmpPredicate::eq);
316
317 // cmpi(eq/ne, alloca, null) -> false/true
318 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
319 return getBoolAttribute(getType(), getContext(),
320 getPredicate() == ICmpPredicate::ne);
321
322 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
323 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
324 Value lhs = getLhs();
325 Value rhs = getRhs();
326 getLhsMutable().assign(rhs);
327 getRhsMutable().assign(lhs);
328 return getResult();
329 }
330
331 return {};
332}
333
334//===----------------------------------------------------------------------===//
335// Printing, parsing and verification for LLVM::AllocaOp.
336//===----------------------------------------------------------------------===//
337
338void AllocaOp::print(OpAsmPrinter &p) {
339 auto funcTy =
340 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
341
342 if (getInalloca())
343 p << " inalloca";
344
345 p << ' ' << getArraySize() << " x " << getElemType();
346 if (getAlignment() && *getAlignment() != 0)
347 p.printOptionalAttrDict((*this)->getAttrs(),
348 {kElemTypeAttrName, getInallocaAttrName()});
349 else
350 p.printOptionalAttrDict(
351 (*this)->getAttrs(),
352 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
353 p << " : " << funcTy;
354}
355
356// <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
357// attribute-dict? `:` type `,` type
358ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
359 OpAsmParser::UnresolvedOperand arraySize;
360 Type type, elemType;
361 SMLoc trailingTypeLoc;
362
363 if (succeeded(parser.parseOptionalKeyword("inalloca")))
364 result.addAttribute(getInallocaAttrName(result.name),
365 UnitAttr::get(parser.getContext()));
366
367 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
368 parser.parseType(elemType) ||
369 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
370 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
371 return failure();
372
373 std::optional<NamedAttribute> alignmentAttr =
374 result.attributes.getNamed("alignment");
375 if (alignmentAttr.has_value()) {
376 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
377 if (!alignmentInt)
378 return parser.emitError(parser.getNameLoc(),
379 "expected integer alignment");
380 if (alignmentInt.getValue().isZero())
381 result.attributes.erase("alignment");
382 }
383
384 // Extract the result type from the trailing function type.
385 auto funcType = llvm::dyn_cast<FunctionType>(type);
386 if (!funcType || funcType.getNumInputs() != 1 ||
387 funcType.getNumResults() != 1)
388 return parser.emitError(
389 trailingTypeLoc,
390 "expected trailing function type with one argument and one result");
391
392 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
393 return failure();
394
395 Type resultType = funcType.getResult(0);
396 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
397 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
398
399 result.addTypes({funcType.getResult(0)});
400 return success();
401}
402
403LogicalResult AllocaOp::verify() {
404 // Only certain target extension types can be used in 'alloca'.
405 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
406 targetExtType && !targetExtType.supportsMemOps())
407 return emitOpError()
408 << "this target extension type cannot be used in alloca";
409
410 return success();
411}
412
413Type AllocaOp::getResultPtrElementType() { return getElemType(); }
414
415//===----------------------------------------------------------------------===//
416// LLVM::BrOp
417//===----------------------------------------------------------------------===//
418
419SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
420 assert(index == 0 && "invalid successor index");
421 return SuccessorOperands(getDestOperandsMutable());
422}
423
424//===----------------------------------------------------------------------===//
425// LLVM::CondBrOp
426//===----------------------------------------------------------------------===//
427
428SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
429 assert(index < getNumSuccessors() && "invalid successor index");
430 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
431 : getFalseDestOperandsMutable());
432}
433
434void CondBrOp::build(OpBuilder &builder, OperationState &result,
435 Value condition, Block *trueDest, ValueRange trueOperands,
436 Block *falseDest, ValueRange falseOperands,
437 std::optional<std::pair<uint32_t, uint32_t>> weights) {
438 DenseI32ArrayAttr weightsAttr;
439 if (weights)
440 weightsAttr =
441 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
442 static_cast<int32_t>(weights->second)});
443
444 build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
445 /*loop_annotation=*/{}, trueDest, falseDest);
446}
447
448//===----------------------------------------------------------------------===//
449// LLVM::SwitchOp
450//===----------------------------------------------------------------------===//
451
452void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
453 Block *defaultDestination, ValueRange defaultOperands,
454 DenseIntElementsAttr caseValues,
455 BlockRange caseDestinations,
456 ArrayRef<ValueRange> caseOperands,
457 ArrayRef<int32_t> branchWeights) {
458 DenseI32ArrayAttr weightsAttr;
459 if (!branchWeights.empty())
460 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
461
462 build(builder, result, value, defaultOperands, caseOperands, caseValues,
463 weightsAttr, defaultDestination, caseDestinations);
464}
465
466void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
467 Block *defaultDestination, ValueRange defaultOperands,
468 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
469 ArrayRef<ValueRange> caseOperands,
470 ArrayRef<int32_t> branchWeights) {
471 DenseIntElementsAttr caseValuesAttr;
472 if (!caseValues.empty()) {
473 ShapedType caseValueType = VectorType::get(
474 static_cast<int64_t>(caseValues.size()), value.getType());
475 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
476 }
477
478 build(builder, result, value, defaultDestination, defaultOperands,
479 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
480}
481
482void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
483 Block *defaultDestination, ValueRange defaultOperands,
484 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
485 ArrayRef<ValueRange> caseOperands,
486 ArrayRef<int32_t> branchWeights) {
487 DenseIntElementsAttr caseValuesAttr;
488 if (!caseValues.empty()) {
489 ShapedType caseValueType = VectorType::get(
490 static_cast<int64_t>(caseValues.size()), value.getType());
491 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
492 }
493
494 build(builder, result, value, defaultDestination, defaultOperands,
495 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
496}
497
498/// <cases> ::= `[` (case (`,` case )* )? `]`
499/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
500static ParseResult parseSwitchOpCases(
501 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
502 SmallVectorImpl<Block *> &caseDestinations,
503 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
504 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
505 if (failed(result: parser.parseLSquare()))
506 return failure();
507 if (succeeded(result: parser.parseOptionalRSquare()))
508 return success();
509 SmallVector<APInt> values;
510 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
511 auto parseCase = [&]() {
512 int64_t value = 0;
513 if (failed(result: parser.parseInteger(result&: value)))
514 return failure();
515 values.push_back(Elt: APInt(bitWidth, value));
516
517 Block *destination;
518 SmallVector<OpAsmParser::UnresolvedOperand> operands;
519 SmallVector<Type> operandTypes;
520 if (parser.parseColon() || parser.parseSuccessor(dest&: destination))
521 return failure();
522 if (!parser.parseOptionalLParen()) {
523 if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::None,
524 /*allowResultNumber=*/false) ||
525 parser.parseColonTypeList(result&: operandTypes) || parser.parseRParen())
526 return failure();
527 }
528 caseDestinations.push_back(Elt: destination);
529 caseOperands.emplace_back(Args&: operands);
530 caseOperandTypes.emplace_back(Args&: operandTypes);
531 return success();
532 };
533 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: parseCase)))
534 return failure();
535
536 ShapedType caseValueType =
537 VectorType::get(static_cast<int64_t>(values.size()), flagType);
538 caseValues = DenseIntElementsAttr::get(caseValueType, values);
539 return parser.parseRSquare();
540}
541
542static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
543 DenseIntElementsAttr caseValues,
544 SuccessorRange caseDestinations,
545 OperandRangeRange caseOperands,
546 const TypeRangeRange &caseOperandTypes) {
547 p << '[';
548 p.printNewline();
549 if (!caseValues) {
550 p << ']';
551 return;
552 }
553
554 size_t index = 0;
555 llvm::interleave(
556 c: llvm::zip(t&: caseValues, u&: caseDestinations),
557 each_fn: [&](auto i) {
558 p << " ";
559 p << std::get<0>(i).getLimitedValue();
560 p << ": ";
561 p.printSuccessorAndUseList(successor: std::get<1>(i), succOperands: caseOperands[index++]);
562 },
563 between_fn: [&] {
564 p << ',';
565 p.printNewline();
566 });
567 p.printNewline();
568 p << ']';
569}
570
571LogicalResult SwitchOp::verify() {
572 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
573 (getCaseValues() &&
574 getCaseValues()->size() !=
575 static_cast<int64_t>(getCaseDestinations().size())))
576 return emitOpError("expects number of case values to match number of "
577 "case destinations");
578 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
579 return emitError("expects number of branch weights to match number of "
580 "successors: ")
581 << getBranchWeights()->size() << " vs " << getNumSuccessors();
582 if (getCaseValues() &&
583 getValue().getType() != getCaseValues()->getElementType())
584 return emitError("expects case value type to match condition value type");
585 return success();
586}
587
588SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
589 assert(index < getNumSuccessors() && "invalid successor index");
590 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
591 : getCaseOperandsMutable(index - 1));
592}
593
594//===----------------------------------------------------------------------===//
595// Code for LLVM::GEPOp.
596//===----------------------------------------------------------------------===//
597
// Out-of-line definition of the static constexpr data member `kDynamicIndex`,
// the sentinel marking a dynamic position in the raw constant index list.
// When compiled as C++17 or later this is redundant (such members are
// implicitly inline), but it is harmless.
constexpr int32_t GEPOp::kDynamicIndex;
599
600GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
601 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
602 getDynamicIndices());
603}
604
605/// Returns the elemental type of any LLVM-compatible vector type or self.
606static Type extractVectorElementType(Type type) {
607 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
608 return vectorType.getElementType();
609 if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
610 return scalableVectorType.getElementType();
611 if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type))
612 return fixedVectorType.getElementType();
613 return type;
614}
615
616/// Destructures the 'indices' parameter into 'rawConstantIndices' and
617/// 'dynamicIndices', encoding the former in the process. In the process,
618/// dynamic indices which are used to index into a structure type are converted
619/// to constant indices when possible. To do this, the GEPs element type should
620/// be passed as first parameter.
621static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
622 SmallVectorImpl<int32_t> &rawConstantIndices,
623 SmallVectorImpl<Value> &dynamicIndices) {
624 for (const GEPArg &iter : indices) {
625 // If the thing we are currently indexing into is a struct we must turn
626 // any integer constants into constant indices. If this is not possible
627 // we don't do anything here. The verifier will catch it and emit a proper
628 // error. All other canonicalization is done in the fold method.
629 bool requiresConst = !rawConstantIndices.empty() &&
630 isa_and_nonnull<LLVMStructType>(Val: currType);
631 if (Value val = llvm::dyn_cast_if_present<Value>(Val: iter)) {
632 APInt intC;
633 if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
634 intC.isSignedIntN(kGEPConstantBitWidth)) {
635 rawConstantIndices.push_back(Elt: intC.getSExtValue());
636 } else {
637 rawConstantIndices.push_back(GEPOp::kDynamicIndex);
638 dynamicIndices.push_back(Elt: val);
639 }
640 } else {
641 rawConstantIndices.push_back(Elt: iter.get<GEPConstantIndex>());
642 }
643
644 // Skip for very first iteration of this loop. First index does not index
645 // within the aggregates, but is just a pointer offset.
646 if (rawConstantIndices.size() == 1 || !currType)
647 continue;
648
649 currType =
650 TypeSwitch<Type, Type>(currType)
651 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
652 LLVMArrayType>([](auto containerType) {
653 return containerType.getElementType();
654 })
655 .Case([&](LLVMStructType structType) -> Type {
656 int64_t memberIndex = rawConstantIndices.back();
657 if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
658 structType.getBody().size())
659 return structType.getBody()[memberIndex];
660 return nullptr;
661 })
662 .Default(Type(nullptr));
663 }
664}
665
666void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
667 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
668 bool inbounds, ArrayRef<NamedAttribute> attributes) {
669 SmallVector<int32_t> rawConstantIndices;
670 SmallVector<Value> dynamicIndices;
671 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
672
673 result.addTypes(resultType);
674 result.addAttributes(attributes);
675 result.addAttribute(getRawConstantIndicesAttrName(result.name),
676 builder.getDenseI32ArrayAttr(rawConstantIndices));
677 if (inbounds) {
678 result.addAttribute(getInboundsAttrName(result.name),
679 builder.getUnitAttr());
680 }
681 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
682 result.addOperands(basePtr);
683 result.addOperands(dynamicIndices);
684}
685
686void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
687 Type elementType, Value basePtr, ValueRange indices,
688 bool inbounds, ArrayRef<NamedAttribute> attributes) {
689 build(builder, result, resultType, elementType, basePtr,
690 SmallVector<GEPArg>(indices), inbounds, attributes);
691}
692
693static ParseResult
694parseGEPIndices(OpAsmParser &parser,
695 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
696 DenseI32ArrayAttr &rawConstantIndices) {
697 SmallVector<int32_t> constantIndices;
698
699 auto idxParser = [&]() -> ParseResult {
700 int32_t constantIndex;
701 OptionalParseResult parsedInteger =
702 parser.parseOptionalInteger(result&: constantIndex);
703 if (parsedInteger.has_value()) {
704 if (failed(result: parsedInteger.value()))
705 return failure();
706 constantIndices.push_back(Elt: constantIndex);
707 return success();
708 }
709
710 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
711 return parser.parseOperand(result&: indices.emplace_back());
712 };
713 if (parser.parseCommaSeparatedList(parseElementFn: idxParser))
714 return failure();
715
716 rawConstantIndices =
717 DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
718 return success();
719}
720
721static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
722 OperandRange indices,
723 DenseI32ArrayAttr rawConstantIndices) {
724 llvm::interleaveComma(
725 c: GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), os&: printer,
726 each_fn: [&](PointerUnion<IntegerAttr, Value> cst) {
727 if (Value val = llvm::dyn_cast_if_present<Value>(Val&: cst))
728 printer.printOperand(value: val);
729 else
730 printer << cst.get<IntegerAttr>().getInt();
731 });
732}
733
734/// For the given `indices`, check if they comply with `baseGEPType`,
735/// especially check against LLVMStructTypes nested within.
736static LogicalResult
737verifyStructIndices(Type baseGEPType, unsigned indexPos,
738 GEPIndicesAdaptor<ValueRange> indices,
739 function_ref<InFlightDiagnostic()> emitOpError) {
740 if (indexPos >= indices.size())
741 // Stop searching
742 return success();
743
744 return TypeSwitch<Type, LogicalResult>(baseGEPType)
745 .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
746 if (!indices[indexPos].is<IntegerAttr>())
747 return emitOpError() << "expected index " << indexPos
748 << " indexing a struct to be constant";
749
750 int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
751 ArrayRef<Type> elementTypes = structType.getBody();
752 if (gepIndex < 0 ||
753 static_cast<size_t>(gepIndex) >= elementTypes.size())
754 return emitOpError() << "index " << indexPos
755 << " indexing a struct is out of bounds";
756
757 // Instead of recursively going into every children types, we only
758 // dive into the one indexed by gepIndex.
759 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
760 indices, emitOpError);
761 })
762 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
763 LLVMArrayType>([&](auto containerType) -> LogicalResult {
764 return verifyStructIndices(containerType.getElementType(), indexPos + 1,
765 indices, emitOpError);
766 })
767 .Default([&](auto otherType) -> LogicalResult {
768 return emitOpError()
769 << "type " << otherType << " cannot be indexed (index #"
770 << indexPos << ")";
771 });
772}
773
774/// Driver function around `verifyStructIndices`.
775static LogicalResult
776verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
777 function_ref<InFlightDiagnostic()> emitOpError) {
778 return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
779}
780
781LogicalResult LLVM::GEPOp::verify() {
782 if (static_cast<size_t>(
783 llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
784 getDynamicIndices().size())
785 return emitOpError("expected as many dynamic indices as specified in '")
786 << getRawConstantIndicesAttrName().getValue() << "'";
787
788 return verifyStructIndices(getElemType(), getIndices(),
789 [&] { return emitOpError(); });
790}
791
792Type GEPOp::getResultPtrElementType() {
793 // Set the initial type currently being used for indexing. This will be
794 // updated as the indices get walked over.
795 Type selectedType = getElemType();
796
797 // Follow the indexed elements in the gep.
798 auto indices = getIndices();
799 for (GEPIndicesAdaptor<ValueRange>::value_type index :
800 llvm::drop_begin(indices)) {
801 // GEPs can only index into aggregates which can be structs or arrays.
802
803 // The resulting type if indexing into an array type is always the element
804 // type, regardless of index.
805 if (auto arrayType = dyn_cast<LLVMArrayType>(selectedType)) {
806 selectedType = arrayType.getElementType();
807 continue;
808 }
809
810 // The GEP verifier ensures that any index into structs are static and
811 // that they refer to a field within the struct.
812 selectedType = cast<DestructurableTypeInterface>(selectedType)
813 .getTypeAtIndex(cast<IntegerAttr>(index));
814 }
815
816 // When there are no more indices, the type currently being used for indexing
817 // is the type of the value pointed at by the returned indexed pointer.
818 return selectedType;
819}
820
821//===----------------------------------------------------------------------===//
822// LoadOp
823//===----------------------------------------------------------------------===//
824
825void LoadOp::getEffects(
826 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
827 &effects) {
828 effects.emplace_back(MemoryEffects::Read::get(), getAddr());
829 // Volatile operations can have target-specific read-write effects on
830 // memory besides the one referred to by the pointer operand.
831 // Similarly, atomic operations that are monotonic or stricter cause
832 // synchronization that from a language point-of-view, are arbitrary
833 // read-writes into memory.
834 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
835 getOrdering() != AtomicOrdering::unordered)) {
836 effects.emplace_back(MemoryEffects::Write::get());
837 effects.emplace_back(MemoryEffects::Read::get());
838 }
839}
840
841/// Returns true if the given type is supported by atomic operations. All
842/// integer and float types with limited bit width are supported. Additionally,
843/// depending on the operation pointers may be supported as well.
844static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) {
845 if (llvm::isa<LLVMPointerType>(type))
846 return isPointerTypeAllowed;
847
848 std::optional<unsigned> bitWidth;
849 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
850 if (!isCompatibleFloatingPointType(type))
851 return false;
852 bitWidth = floatType.getWidth();
853 }
854 if (auto integerType = llvm::dyn_cast<IntegerType>(type))
855 bitWidth = integerType.getWidth();
856 // The type is neither an integer, float, or pointer type.
857 if (!bitWidth)
858 return false;
859 return *bitWidth == 8 || *bitWidth == 16 || *bitWidth == 32 ||
860 *bitWidth == 64;
861}
862
863/// Verifies the attributes and the type of atomic memory access operations.
864template <typename OpTy>
865LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
866 ArrayRef<AtomicOrdering> unsupportedOrderings) {
867 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
868 if (!isTypeCompatibleWithAtomicOp(type: valueType,
869 /*isPointerTypeAllowed=*/true))
870 return memOp.emitOpError("unsupported type ")
871 << valueType << " for atomic access";
872 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
873 return memOp.emitOpError("unsupported ordering '")
874 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
875 if (!memOp.getAlignment())
876 return memOp.emitOpError("expected alignment for atomic access");
877 return success();
878 }
879 if (memOp.getSyncscope())
880 return memOp.emitOpError(
881 "expected syncscope to be null for non-atomic access");
882 return success();
883}
884
885LogicalResult LoadOp::verify() {
886 Type valueType = getResult().getType();
887 return verifyAtomicMemOp(*this, valueType,
888 {AtomicOrdering::release, AtomicOrdering::acq_rel});
889}
890
891void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
892 Value addr, unsigned alignment, bool isVolatile,
893 bool isNonTemporal, bool isInvariant,
894 AtomicOrdering ordering, StringRef syncscope) {
895 build(builder, state, type, addr,
896 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
897 isNonTemporal, isInvariant, ordering,
898 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
899 /*access_groups=*/nullptr,
900 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
901 /*tbaa=*/nullptr);
902}
903
904//===----------------------------------------------------------------------===//
905// StoreOp
906//===----------------------------------------------------------------------===//
907
908void StoreOp::getEffects(
909 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
910 &effects) {
911 effects.emplace_back(MemoryEffects::Write::get(), getAddr());
912 // Volatile operations can have target-specific read-write effects on
913 // memory besides the one referred to by the pointer operand.
914 // Similarly, atomic operations that are monotonic or stricter cause
915 // synchronization that from a language point-of-view, are arbitrary
916 // read-writes into memory.
917 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
918 getOrdering() != AtomicOrdering::unordered)) {
919 effects.emplace_back(MemoryEffects::Write::get());
920 effects.emplace_back(MemoryEffects::Read::get());
921 }
922}
923
924LogicalResult StoreOp::verify() {
925 Type valueType = getValue().getType();
926 return verifyAtomicMemOp(*this, valueType,
927 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
928}
929
930void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
931 Value addr, unsigned alignment, bool isVolatile,
932 bool isNonTemporal, AtomicOrdering ordering,
933 StringRef syncscope) {
934 build(builder, state, value, addr,
935 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
936 isNonTemporal, ordering,
937 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
938 /*access_groups=*/nullptr,
939 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
940}
941
942//===----------------------------------------------------------------------===//
943// CallOp
944//===----------------------------------------------------------------------===//
945
946/// Gets the MLIR Op-like result types of a LLVMFunctionType.
947static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
948 SmallVector<Type, 1> results;
949 Type resultType = calleeType.getReturnType();
950 if (!isa<LLVM::LLVMVoidType>(Val: resultType))
951 results.push_back(Elt: resultType);
952 return results;
953}
954
955/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
956static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
957 ValueRange args) {
958 Type resultType;
959 if (results.empty())
960 resultType = LLVMVoidType::get(ctx: context);
961 else
962 resultType = results.front();
963 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
964 /*isVarArg=*/false);
965}
966
967void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
968 StringRef callee, ValueRange args) {
969 build(builder, state, results, builder.getStringAttr(callee), args);
970}
971
972void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
973 StringAttr callee, ValueRange args) {
974 build(builder, state, results, SymbolRefAttr::get(callee), args);
975}
976
977void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
978 FlatSymbolRefAttr callee, ValueRange args) {
979 assert(callee && "expected non-null callee in direct call builder");
980 build(builder, state, results,
981 TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
982 callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
983 /*CConv=*/nullptr,
984 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
985 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
986}
987
988void CallOp::build(OpBuilder &builder, OperationState &state,
989 LLVMFunctionType calleeType, StringRef callee,
990 ValueRange args) {
991 build(builder, state, calleeType, builder.getStringAttr(callee), args);
992}
993
994void CallOp::build(OpBuilder &builder, OperationState &state,
995 LLVMFunctionType calleeType, StringAttr callee,
996 ValueRange args) {
997 build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
998}
999
1000void CallOp::build(OpBuilder &builder, OperationState &state,
1001 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1002 ValueRange args) {
1003 build(builder, state, getCallOpResultTypes(calleeType),
1004 TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
1005 /*branch_weights=*/nullptr, /*CConv=*/nullptr,
1006 /*access_groups=*/nullptr,
1007 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1008}
1009
1010void CallOp::build(OpBuilder &builder, OperationState &state,
1011 LLVMFunctionType calleeType, ValueRange args) {
1012 build(builder, state, getCallOpResultTypes(calleeType),
1013 TypeAttr::get(calleeType), /*callee=*/nullptr, args,
1014 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1015 /*CConv=*/nullptr,
1016 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1017 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1018}
1019
1020void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1021 ValueRange args) {
1022 auto calleeType = func.getFunctionType();
1023 build(builder, state, getCallOpResultTypes(calleeType),
1024 TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
1025 /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1026 /*CConv=*/nullptr,
1027 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1028 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
1029}
1030
1031CallInterfaceCallable CallOp::getCallableForCallee() {
1032 // Direct call.
1033 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1034 return calleeAttr;
1035 // Indirect call, callee Value is the first operand.
1036 return getOperand(0);
1037}
1038
1039void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1040 // Direct call.
1041 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1042 auto symRef = callee.get<SymbolRefAttr>();
1043 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1044 }
1045 // Indirect call, callee Value is the first operand.
1046 return setOperand(0, callee.get<Value>());
1047}
1048
1049Operation::operand_range CallOp::getArgOperands() {
1050 return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
1051}
1052
1053MutableOperandRange CallOp::getArgOperandsMutable() {
1054 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1055 getCalleeOperands().size());
1056}
1057
1058/// Verify that an inlinable callsite of a debug-info-bearing function in a
1059/// debug-info-bearing function has a debug location attached to it. This
1060/// mirrors an LLVM IR verifier.
1061static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1062 if (callee.isExternal())
1063 return success();
1064 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1065 if (!parentFunc)
1066 return success();
1067
1068 auto hasSubprogram = [](Operation *op) {
1069 return op->getLoc()
1070 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1071 nullptr;
1072 };
1073 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1074 return success();
1075 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1076 if (!containsLoc)
1077 return callOp.emitError()
1078 << "inlinable function call in a function with a DISubprogram "
1079 "location must have a debug location";
1080 return success();
1081}
1082
1083LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1084 if (getNumResults() > 1)
1085 return emitOpError("must have 0 or 1 result");
1086
1087 // Type for the callee, we'll get it differently depending if it is a direct
1088 // or indirect call.
1089 Type fnType;
1090
1091 bool isIndirect = false;
1092
1093 // If this is an indirect call, the callee attribute is missing.
1094 FlatSymbolRefAttr calleeName = getCalleeAttr();
1095 if (!calleeName) {
1096 isIndirect = true;
1097 if (!getNumOperands())
1098 return emitOpError(
1099 "must have either a `callee` attribute or at least an operand");
1100 auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1101 if (!ptrType)
1102 return emitOpError("indirect call expects a pointer as callee: ")
1103 << getOperand(0).getType();
1104
1105 return success();
1106 } else {
1107 Operation *callee =
1108 symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1109 if (!callee)
1110 return emitOpError()
1111 << "'" << calleeName.getValue()
1112 << "' does not reference a symbol in the current scope";
1113 auto fn = dyn_cast<LLVMFuncOp>(callee);
1114 if (!fn)
1115 return emitOpError() << "'" << calleeName.getValue()
1116 << "' does not reference a valid LLVM function";
1117
1118 if (failed(verifyCallOpDebugInfo(*this, fn)))
1119 return failure();
1120 fnType = fn.getFunctionType();
1121 }
1122
1123 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1124 if (!funcType)
1125 return emitOpError("callee does not have a functional type: ") << fnType;
1126
1127 if (funcType.isVarArg() && !getCalleeType())
1128 return emitOpError() << "missing callee type attribute for vararg call";
1129
1130 // Verify that the operand and result types match the callee.
1131
1132 if (!funcType.isVarArg() &&
1133 funcType.getNumParams() != (getNumOperands() - isIndirect))
1134 return emitOpError() << "incorrect number of operands ("
1135 << (getNumOperands() - isIndirect)
1136 << ") for callee (expecting: "
1137 << funcType.getNumParams() << ")";
1138
1139 if (funcType.getNumParams() > (getNumOperands() - isIndirect))
1140 return emitOpError() << "incorrect number of operands ("
1141 << (getNumOperands() - isIndirect)
1142 << ") for varargs callee (expecting at least: "
1143 << funcType.getNumParams() << ")";
1144
1145 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1146 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1147 return emitOpError() << "operand type mismatch for operand " << i << ": "
1148 << getOperand(i + isIndirect).getType()
1149 << " != " << funcType.getParamType(i);
1150
1151 if (getNumResults() == 0 &&
1152 !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1153 return emitOpError() << "expected function call to produce a value";
1154
1155 if (getNumResults() != 0 &&
1156 llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1157 return emitOpError()
1158 << "calling function with void result must not produce values";
1159
1160 if (getNumResults() > 1)
1161 return emitOpError()
1162 << "expected LLVM function call to produce 0 or 1 result";
1163
1164 if (getNumResults() && getResult().getType() != funcType.getReturnType())
1165 return emitOpError() << "result type mismatch: " << getResult().getType()
1166 << " != " << funcType.getReturnType();
1167
1168 return success();
1169}
1170
1171void CallOp::print(OpAsmPrinter &p) {
1172 auto callee = getCallee();
1173 bool isDirect = callee.has_value();
1174
1175 LLVMFunctionType calleeType;
1176 bool isVarArg = false;
1177
1178 if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1179 calleeType = *optionalCalleeType;
1180 isVarArg = calleeType.isVarArg();
1181 }
1182
1183 p << ' ';
1184
1185 // Print calling convention.
1186 if (getCConv() != LLVM::CConv::C)
1187 p << stringifyCConv(getCConv()) << ' ';
1188
1189 // Print the direct callee if present as a function attribute, or an indirect
1190 // callee (first operand) otherwise.
1191 if (isDirect)
1192 p.printSymbolName(callee.value());
1193 else
1194 p << getOperand(0);
1195
1196 auto args = getOperands().drop_front(isDirect ? 0 : 1);
1197 p << '(' << args << ')';
1198
1199 if (isVarArg)
1200 p << " vararg(" << calleeType << ")";
1201
1202 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1203 {getCConvAttrName(), "callee", "callee_type"});
1204
1205 p << " : ";
1206 if (!isDirect)
1207 p << getOperand(0).getType() << ", ";
1208
1209 // Reconstruct the function MLIR function type from operand and result types.
1210 p.printFunctionalType(args.getTypes(), getResultTypes());
1211}
1212
1213/// Parses the type of a call operation and resolves the operands if the parsing
1214/// succeeds. Returns failure otherwise.
1215static ParseResult parseCallTypeAndResolveOperands(
1216 OpAsmParser &parser, OperationState &result, bool isDirect,
1217 ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1218 SMLoc trailingTypesLoc = parser.getCurrentLocation();
1219 SmallVector<Type> types;
1220 if (parser.parseColonTypeList(result&: types))
1221 return failure();
1222
1223 if (isDirect && types.size() != 1)
1224 return parser.emitError(loc: trailingTypesLoc,
1225 message: "expected direct call to have 1 trailing type");
1226 if (!isDirect && types.size() != 2)
1227 return parser.emitError(loc: trailingTypesLoc,
1228 message: "expected indirect call to have 2 trailing types");
1229
1230 auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
1231 if (!funcType)
1232 return parser.emitError(loc: trailingTypesLoc,
1233 message: "expected trailing function type");
1234 if (funcType.getNumResults() > 1)
1235 return parser.emitError(loc: trailingTypesLoc,
1236 message: "expected function with 0 or 1 result");
1237 if (funcType.getNumResults() == 1 &&
1238 llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
1239 return parser.emitError(loc: trailingTypesLoc,
1240 message: "expected a non-void result type");
1241
1242 // The head element of the types list matches the callee type for
1243 // indirect calls, while the types list is emtpy for direct calls.
1244 // Append the function input types to resolve the call operation
1245 // operands.
1246 llvm::append_range(types, funcType.getInputs());
1247 if (parser.resolveOperands(operands, types, loc: parser.getNameLoc(),
1248 result&: result.operands))
1249 return failure();
1250 if (funcType.getNumResults() != 0)
1251 result.addTypes(funcType.getResults());
1252
1253 return success();
1254}
1255
1256/// Parses an optional function pointer operand before the call argument list
1257/// for indirect calls, or stops parsing at the function identifier otherwise.
1258static ParseResult parseOptionalCallFuncPtr(
1259 OpAsmParser &parser,
1260 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) {
1261 OpAsmParser::UnresolvedOperand funcPtrOperand;
1262 OptionalParseResult parseResult = parser.parseOptionalOperand(result&: funcPtrOperand);
1263 if (parseResult.has_value()) {
1264 if (failed(result: *parseResult))
1265 return *parseResult;
1266 operands.push_back(Elt: funcPtrOperand);
1267 }
1268 return success();
1269}
1270
1271// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
1272// `(` ssa-use-list `)`
1273// ( `vararg(` var-arg-func-type `)` )?
1274// attribute-dict? `:` (type `,`)? function-type
1275ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1276 SymbolRefAttr funcAttr;
1277 TypeAttr calleeType;
1278 SmallVector<OpAsmParser::UnresolvedOperand> operands;
1279
1280 // Default to C Calling Convention if no keyword is provided.
1281 result.addAttribute(
1282 getCConvAttrName(result.name),
1283 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1284 parser, result, LLVM::CConv::C)));
1285
1286 // Parse a function pointer for indirect calls.
1287 if (parseOptionalCallFuncPtr(parser, operands))
1288 return failure();
1289 bool isDirect = operands.empty();
1290
1291 // Parse a function identifier for direct calls.
1292 if (isDirect)
1293 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1294 return failure();
1295
1296 // Parse the function arguments.
1297 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1298 return failure();
1299
1300 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1301 if (isVarArg) {
1302 if (parser.parseLParen().failed() ||
1303 parser.parseAttribute(calleeType, "callee_type", result.attributes)
1304 .failed() ||
1305 parser.parseRParen().failed())
1306 return failure();
1307 }
1308
1309 if (parser.parseOptionalAttrDict(result.attributes))
1310 return failure();
1311
1312 // Parse the trailing type list and resolve the operands.
1313 return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
1314}
1315
1316LLVMFunctionType CallOp::getCalleeFunctionType() {
1317 if (getCalleeType())
1318 return *getCalleeType();
1319 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1320}
1321
1322///===---------------------------------------------------------------------===//
1323/// LLVM::InvokeOp
1324///===---------------------------------------------------------------------===//
1325
1326void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1327 ValueRange ops, Block *normal, ValueRange normalOps,
1328 Block *unwind, ValueRange unwindOps) {
1329 auto calleeType = func.getFunctionType();
1330 build(builder, state, getCallOpResultTypes(calleeType),
1331 TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
1332 unwindOps, nullptr, nullptr, normal, unwind);
1333}
1334
1335void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1336 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1337 ValueRange normalOps, Block *unwind,
1338 ValueRange unwindOps) {
1339 build(builder, state, tys,
1340 TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
1341 ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
1342}
1343
1344void InvokeOp::build(OpBuilder &builder, OperationState &state,
1345 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1346 ValueRange ops, Block *normal, ValueRange normalOps,
1347 Block *unwind, ValueRange unwindOps) {
1348 build(builder, state, getCallOpResultTypes(calleeType),
1349 TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
1350 nullptr, normal, unwind);
1351}
1352
1353SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1354 assert(index < getNumSuccessors() && "invalid successor index");
1355 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1356 : getUnwindDestOperandsMutable());
1357}
1358
1359CallInterfaceCallable InvokeOp::getCallableForCallee() {
1360 // Direct call.
1361 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1362 return calleeAttr;
1363 // Indirect call, callee Value is the first operand.
1364 return getOperand(0);
1365}
1366
1367void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1368 // Direct call.
1369 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1370 auto symRef = callee.get<SymbolRefAttr>();
1371 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1372 }
1373 // Indirect call, callee Value is the first operand.
1374 return setOperand(0, callee.get<Value>());
1375}
1376
1377Operation::operand_range InvokeOp::getArgOperands() {
1378 return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
1379}
1380
1381MutableOperandRange InvokeOp::getArgOperandsMutable() {
1382 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1383 getCalleeOperands().size());
1384}
1385
1386LogicalResult InvokeOp::verify() {
1387 if (getNumResults() > 1)
1388 return emitOpError("must have 0 or 1 result");
1389
1390 Block *unwindDest = getUnwindDest();
1391 if (unwindDest->empty())
1392 return emitError("must have at least one operation in unwind destination");
1393
1394 // In unwind destination, first operation must be LandingpadOp
1395 if (!isa<LandingpadOp>(unwindDest->front()))
1396 return emitError("first operation in unwind destination should be a "
1397 "llvm.landingpad operation");
1398
1399 return success();
1400}
1401
1402void InvokeOp::print(OpAsmPrinter &p) {
1403 auto callee = getCallee();
1404 bool isDirect = callee.has_value();
1405
1406 LLVMFunctionType calleeType;
1407 bool isVarArg = false;
1408
1409 if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
1410 calleeType = *optionalCalleeType;
1411 isVarArg = calleeType.isVarArg();
1412 }
1413
1414 p << ' ';
1415
1416 // Print calling convention.
1417 if (getCConv() != LLVM::CConv::C)
1418 p << stringifyCConv(getCConv()) << ' ';
1419
1420 // Either function name or pointer
1421 if (isDirect)
1422 p.printSymbolName(callee.value());
1423 else
1424 p << getOperand(0);
1425
1426 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
1427 p << " to ";
1428 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
1429 p << " unwind ";
1430 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
1431
1432 if (isVarArg)
1433 p << " vararg(" << calleeType << ")";
1434
1435 p.printOptionalAttrDict((*this)->getAttrs(),
1436 {InvokeOp::getOperandSegmentSizeAttr(), "callee",
1437 "callee_type", InvokeOp::getCConvAttrName()});
1438
1439 p << " : ";
1440 if (!isDirect)
1441 p << getOperand(0).getType() << ", ";
1442 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
1443 getResultTypes());
1444}
1445
1446// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1447// `(` ssa-use-list `)`
1448// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1449// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1450// ( `vararg(` var-arg-func-type `)` )?
1451// attribute-dict? `:` (type `,`)? function-type
1452ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1453 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1454 SymbolRefAttr funcAttr;
1455 TypeAttr calleeType;
1456 Block *normalDest, *unwindDest;
1457 SmallVector<Value, 4> normalOperands, unwindOperands;
1458 Builder &builder = parser.getBuilder();
1459
1460 // Default to C Calling Convention if no keyword is provided.
1461 result.addAttribute(
1462 getCConvAttrName(result.name),
1463 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1464 parser, result, LLVM::CConv::C)));
1465
1466 // Parse a function pointer for indirect calls.
1467 if (parseOptionalCallFuncPtr(parser, operands))
1468 return failure();
1469 bool isDirect = operands.empty();
1470
1471 // Parse a function identifier for direct calls.
1472 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1473 return failure();
1474
1475 // Parse the function arguments.
1476 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1477 parser.parseKeyword("to") ||
1478 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1479 parser.parseKeyword("unwind") ||
1480 parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1481 return failure();
1482
1483 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1484 if (isVarArg) {
1485 if (parser.parseLParen().failed() ||
1486 parser.parseAttribute(calleeType, "callee_type", result.attributes)
1487 .failed() ||
1488 parser.parseRParen().failed())
1489 return failure();
1490 }
1491
1492 if (parser.parseOptionalAttrDict(result.attributes))
1493 return failure();
1494
1495 // Parse the trailing type list and resolve the function operands.
1496 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
1497 return failure();
1498
1499 result.addSuccessors({normalDest, unwindDest});
1500 result.addOperands(normalOperands);
1501 result.addOperands(unwindOperands);
1502
1503 result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(),
1504 builder.getDenseI32ArrayAttr(
1505 {static_cast<int32_t>(operands.size()),
1506 static_cast<int32_t>(normalOperands.size()),
1507 static_cast<int32_t>(unwindOperands.size())}));
1508 return success();
1509}
1510
1511LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1512 if (getCalleeType())
1513 return *getCalleeType();
1514 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1515}
1516
1517///===----------------------------------------------------------------------===//
1518/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1519///===----------------------------------------------------------------------===//
1520
1521LogicalResult LandingpadOp::verify() {
1522 Value value;
1523 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1524 if (!func.getPersonality())
1525 return emitError(
1526 "llvm.landingpad needs to be in a function with a personality");
1527 }
1528
1529 // Consistency of llvm.landingpad result types is checked in
1530 // LLVMFuncOp::verify().
1531
1532 if (!getCleanup() && getOperands().empty())
1533 return emitError("landingpad instruction expects at least one clause or "
1534 "cleanup attribute");
1535
1536 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1537 value = getOperand(idx);
1538 bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1539 if (isFilter) {
1540 // FIXME: Verify filter clauses when arrays are appropriately handled
1541 } else {
1542 // catch - global addresses only.
1543 // Bitcast ops should have global addresses as their args.
1544 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1545 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1546 continue;
1547 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1548 << "global addresses expected as operand to "
1549 "bitcast used in clauses for landingpad";
1550 }
1551 // ZeroOp and AddressOfOp allowed
1552 if (value.getDefiningOp<ZeroOp>())
1553 continue;
1554 if (value.getDefiningOp<AddressOfOp>())
1555 continue;
1556 return emitError("clause #")
1557 << idx << " is not a known constant - null, addressof, bitcast";
1558 }
1559 }
1560 return success();
1561}
1562
1563void LandingpadOp::print(OpAsmPrinter &p) {
1564 p << (getCleanup() ? " cleanup " : " ");
1565
1566 // Clauses
1567 for (auto value : getOperands()) {
1568 // Similar to llvm - if clause is an array type then it is filter
1569 // clause else catch clause
1570 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1571 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1572 << value.getType() << ") ";
1573 }
1574
1575 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1576
1577 p << ": " << getType();
1578}
1579
1580// <operation> ::= `llvm.landingpad` `cleanup`?
1581// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1582ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1583 // Check for cleanup
1584 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1585 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1586
1587 // Parse clauses with types
1588 while (succeeded(parser.parseOptionalLParen()) &&
1589 (succeeded(parser.parseOptionalKeyword("filter")) ||
1590 succeeded(parser.parseOptionalKeyword("catch")))) {
1591 OpAsmParser::UnresolvedOperand operand;
1592 Type ty;
1593 if (parser.parseOperand(operand) || parser.parseColon() ||
1594 parser.parseType(ty) ||
1595 parser.resolveOperand(operand, ty, result.operands) ||
1596 parser.parseRParen())
1597 return failure();
1598 }
1599
1600 Type type;
1601 if (parser.parseColon() || parser.parseType(type))
1602 return failure();
1603
1604 result.addTypes(type);
1605 return success();
1606}
1607
1608//===----------------------------------------------------------------------===//
1609// ExtractValueOp
1610//===----------------------------------------------------------------------===//
1611
1612/// Extract the type at `position` in the LLVM IR aggregate type
1613/// `containerType`. Each element of `position` is an index into a nested
1614/// aggregate type. Return the resulting type or emit an error.
1615static Type getInsertExtractValueElementType(
1616 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1617 ArrayRef<int64_t> position) {
1618 Type llvmType = containerType;
1619 if (!isCompatibleType(type: containerType)) {
1620 emitError("expected LLVM IR Dialect type, got ") << containerType;
1621 return {};
1622 }
1623
1624 // Infer the element type from the structure type: iteratively step inside the
1625 // type by taking the element type, indexed by the position attribute for
1626 // structures. Check the position index before accessing, it is supposed to
1627 // be in bounds.
1628 for (int64_t idx : position) {
1629 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1630 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1631 emitError("position out of bounds: ") << idx;
1632 return {};
1633 }
1634 llvmType = arrayType.getElementType();
1635 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1636 if (idx < 0 ||
1637 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1638 emitError("position out of bounds: ") << idx;
1639 return {};
1640 }
1641 llvmType = structType.getBody()[idx];
1642 } else {
1643 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1644 return {};
1645 }
1646 }
1647 return llvmType;
1648}
1649
1650/// Extract the type at `position` in the wrapped LLVM IR aggregate type
1651/// `containerType`.
1652static Type getInsertExtractValueElementType(Type llvmType,
1653 ArrayRef<int64_t> position) {
1654 for (int64_t idx : position) {
1655 if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1656 llvmType = structType.getBody()[idx];
1657 else
1658 llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1659 }
1660 return llvmType;
1661}
1662
/// Folds an extractvalue through the chain of insertvalue ops that produce its
/// container. For each insertvalue in the chain:
///  - if it writes exactly the extracted position, the inserted value is
///    returned as the fold result;
///  - if one position is a prefix of the other, folding stops (see the example
///    below);
///  - otherwise the positions diverge, so that insertvalue cannot affect the
///    extracted element. This op's container operand is rewritten in place to
///    skip over it, and the walk continues with its own container.
/// Returns an empty OpFoldResult when nothing changed, or this op's own result
/// to signal that the op was updated in place.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  // Stays empty until the first in-place rewrite of the container operand.
  OpFoldResult result = {};
  while (insertValueOp) {
    if (getPosition() == insertValueOp.getPosition())
      return insertValueOp.getValue();
    unsigned min =
        std::min(getPosition().size(), insertValueOp.getPosition().size());
    // If one position is a prefix of the other, stop propagating back, because
    // doing so would miss dependencies. For instance, %3 should not fold to
    // %f0 in the following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (getPosition().take_front(min) ==
        insertValueOp.getPosition().take_front(min))
      return result;

    // If neither a prefix nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    getContainerMutable().assign(insertValueOp.getContainer());
    // Returning our own result tells the folding driver the op changed in place.
    result = getResult();
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
1694
1695LogicalResult ExtractValueOp::verify() {
1696 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1697 Type valueType = getInsertExtractValueElementType(
1698 emitError, getContainer().getType(), getPosition());
1699 if (!valueType)
1700 return failure();
1701
1702 if (getRes().getType() != valueType)
1703 return emitOpError() << "Type mismatch: extracting from "
1704 << getContainer().getType() << " should produce "
1705 << valueType << " but this op returns "
1706 << getRes().getType();
1707 return success();
1708}
1709
1710void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
1711 Value container, ArrayRef<int64_t> position) {
1712 build(builder, state,
1713 getInsertExtractValueElementType(container.getType(), position),
1714 container, builder.getAttr<DenseI64ArrayAttr>(position));
1715}
1716
1717//===----------------------------------------------------------------------===//
1718// InsertValueOp
1719//===----------------------------------------------------------------------===//
1720
1721/// Infer the value type from the container type and position.
1722static ParseResult
1723parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
1724 Type containerType,
1725 DenseI64ArrayAttr position) {
1726 valueType = getInsertExtractValueElementType(
1727 [&](StringRef msg) {
1728 return parser.emitError(loc: parser.getCurrentLocation(), message: msg);
1729 },
1730 containerType, position.asArrayRef());
1731 return success(isSuccess: !!valueType);
1732}
1733
1734/// Nothing to print for an inferred type.
1735static void printInsertExtractValueElementType(AsmPrinter &printer,
1736 Operation *op, Type valueType,
1737 Type containerType,
1738 DenseI64ArrayAttr position) {}
1739
1740LogicalResult InsertValueOp::verify() {
1741 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1742 Type valueType = getInsertExtractValueElementType(
1743 emitError, getContainer().getType(), getPosition());
1744 if (!valueType)
1745 return failure();
1746
1747 if (getValue().getType() != valueType)
1748 return emitOpError() << "Type mismatch: cannot insert "
1749 << getValue().getType() << " into "
1750 << getContainer().getType();
1751
1752 return success();
1753}
1754
1755//===----------------------------------------------------------------------===//
1756// ReturnOp
1757//===----------------------------------------------------------------------===//
1758
1759LogicalResult ReturnOp::verify() {
1760 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
1761 if (!parent)
1762 return success();
1763
1764 Type expectedType = parent.getFunctionType().getReturnType();
1765 if (llvm::isa<LLVMVoidType>(expectedType)) {
1766 if (!getArg())
1767 return success();
1768 InFlightDiagnostic diag = emitOpError("expected no operands");
1769 diag.attachNote(parent->getLoc()) << "when returning from function";
1770 return diag;
1771 }
1772 if (!getArg()) {
1773 if (llvm::isa<LLVMVoidType>(expectedType))
1774 return success();
1775 InFlightDiagnostic diag = emitOpError("expected 1 operand");
1776 diag.attachNote(parent->getLoc()) << "when returning from function";
1777 return diag;
1778 }
1779 if (expectedType != getArg().getType()) {
1780 InFlightDiagnostic diag = emitOpError("mismatching result types");
1781 diag.attachNote(parent->getLoc()) << "when returning from function";
1782 return diag;
1783 }
1784 return success();
1785}
1786
1787//===----------------------------------------------------------------------===//
// Symbol lookup helpers and verifier for LLVM::AddressOfOp.
1789//===----------------------------------------------------------------------===//
1790
1791static Operation *parentLLVMModule(Operation *op) {
1792 Operation *module = op->getParentOp();
1793 while (module && !satisfiesLLVMModule(op: module))
1794 module = module->getParentOp();
1795 assert(module && "unexpected operation outside of a module");
1796 return module;
1797}
1798
1799GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
1800 return dyn_cast_or_null<GlobalOp>(
1801 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
1802}
1803
1804LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
1805 return dyn_cast_or_null<LLVMFuncOp>(
1806 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
1807}
1808
1809LogicalResult
1810AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1811 Operation *symbol =
1812 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
1813
1814 auto global = dyn_cast_or_null<GlobalOp>(symbol);
1815 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
1816
1817 if (!global && !function)
1818 return emitOpError(
1819 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1820
1821 LLVMPointerType type = getType();
1822 if (global && global.getAddrSpace() != type.getAddressSpace())
1823 return emitOpError("pointer address space must match address space of the "
1824 "referenced global");
1825
1826 return success();
1827}
1828
1829//===----------------------------------------------------------------------===//
// Builder and verifier for LLVM::ComdatOp.
1831//===----------------------------------------------------------------------===//
1832
1833void ComdatOp::build(OpBuilder &builder, OperationState &result,
1834 StringRef symName) {
1835 result.addAttribute(getSymNameAttrName(result.name),
1836 builder.getStringAttr(symName));
1837 Region *body = result.addRegion();
1838 body->emplaceBlock();
1839}
1840
1841LogicalResult ComdatOp::verifyRegions() {
1842 Region &body = getBody();
1843 for (Operation &op : body.getOps())
1844 if (!isa<ComdatSelectorOp>(op))
1845 return op.emitError(
1846 "only comdat selector symbols can appear in a comdat region");
1847
1848 return success();
1849}
1850
1851//===----------------------------------------------------------------------===//
1852// Builder, printer and verifier for LLVM::GlobalOp.
1853//===----------------------------------------------------------------------===//
1854
1855void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1856 bool isConstant, Linkage linkage, StringRef name,
1857 Attribute value, uint64_t alignment, unsigned addrSpace,
1858 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
1859 ArrayRef<NamedAttribute> attrs,
1860 DIGlobalVariableExpressionAttr dbgExpr) {
1861 result.addAttribute(getSymNameAttrName(result.name),
1862 builder.getStringAttr(name));
1863 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
1864 if (isConstant)
1865 result.addAttribute(getConstantAttrName(result.name),
1866 builder.getUnitAttr());
1867 if (value)
1868 result.addAttribute(getValueAttrName(result.name), value);
1869 if (dsoLocal)
1870 result.addAttribute(getDsoLocalAttrName(result.name),
1871 builder.getUnitAttr());
1872 if (threadLocal)
1873 result.addAttribute(getThreadLocal_AttrName(result.name),
1874 builder.getUnitAttr());
1875 if (comdat)
1876 result.addAttribute(getComdatAttrName(result.name), comdat);
1877
1878 // Only add an alignment attribute if the "alignment" input
1879 // is different from 0. The value must also be a power of two, but
1880 // this is tested in GlobalOp::verify, not here.
1881 if (alignment != 0)
1882 result.addAttribute(getAlignmentAttrName(result.name),
1883 builder.getI64IntegerAttr(alignment));
1884
1885 result.addAttribute(getLinkageAttrName(result.name),
1886 LinkageAttr::get(builder.getContext(), linkage));
1887 if (addrSpace != 0)
1888 result.addAttribute(getAddrSpaceAttrName(result.name),
1889 builder.getI32IntegerAttr(addrSpace));
1890 result.attributes.append(attrs.begin(), attrs.end());
1891
1892 if (dbgExpr)
1893 result.addAttribute(getDbgExprAttrName(result.name), dbgExpr);
1894
1895 result.addRegion();
1896}
1897
1898void GlobalOp::print(OpAsmPrinter &p) {
1899 p << ' ' << stringifyLinkage(getLinkage()) << ' ';
1900 StringRef visibility = stringifyVisibility(getVisibility_());
1901 if (!visibility.empty())
1902 p << visibility << ' ';
1903 if (getThreadLocal_())
1904 p << "thread_local ";
1905 if (auto unnamedAddr = getUnnamedAddr()) {
1906 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
1907 if (!str.empty())
1908 p << str << ' ';
1909 }
1910 if (getConstant())
1911 p << "constant ";
1912 p.printSymbolName(getSymName());
1913 p << '(';
1914 if (auto value = getValueOrNull())
1915 p.printAttribute(value);
1916 p << ')';
1917 if (auto comdat = getComdat())
1918 p << " comdat(" << *comdat << ')';
1919
1920 // Note that the alignment attribute is printed using the
1921 // default syntax here, even though it is an inherent attribute
1922 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1923 p.printOptionalAttrDict((*this)->getAttrs(),
1924 {SymbolTable::getSymbolAttrName(),
1925 getGlobalTypeAttrName(), getConstantAttrName(),
1926 getValueAttrName(), getLinkageAttrName(),
1927 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
1928 getVisibility_AttrName(), getComdatAttrName(),
1929 getUnnamedAddrAttrName()});
1930
1931 // Print the trailing type unless it's a string global.
1932 if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
1933 return;
1934 p << " : " << getType();
1935
1936 Region &initializer = getInitializerRegion();
1937 if (!initializer.empty()) {
1938 p << ' ';
1939 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1940 }
1941}
1942
1943static LogicalResult verifyComdat(Operation *op,
1944 std::optional<SymbolRefAttr> attr) {
1945 if (!attr)
1946 return success();
1947
1948 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
1949 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
1950 return op->emitError() << "expected comdat symbol";
1951
1952 return success();
1953}
1954
1955// operation ::= `llvm.mlir.global` linkage? visibility?
1956// (`unnamed_addr` | `local_unnamed_addr`)?
1957// `thread_local`? `constant`? `@` identifier
1958// `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
1959// attribute-list? (`:` type)? region?
1960//
1961// The type can be omitted for string attributes, in which case it will be
1962// inferred from the value of the string as [strlen(value) x i8].
1963ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
1964 MLIRContext *ctx = parser.getContext();
1965 // Parse optional linkage, default to External.
1966 result.addAttribute(getLinkageAttrName(result.name),
1967 LLVM::LinkageAttr::get(
1968 ctx, parseOptionalLLVMKeyword<Linkage>(
1969 parser, result, LLVM::Linkage::External)));
1970
1971 // Parse optional visibility, default to Default.
1972 result.addAttribute(getVisibility_AttrName(result.name),
1973 parser.getBuilder().getI64IntegerAttr(
1974 parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
1975 parser, result, LLVM::Visibility::Default)));
1976
1977 // Parse optional UnnamedAddr, default to None.
1978 result.addAttribute(getUnnamedAddrAttrName(result.name),
1979 parser.getBuilder().getI64IntegerAttr(
1980 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
1981 parser, result, LLVM::UnnamedAddr::None)));
1982
1983 if (succeeded(parser.parseOptionalKeyword("thread_local")))
1984 result.addAttribute(getThreadLocal_AttrName(result.name),
1985 parser.getBuilder().getUnitAttr());
1986
1987 if (succeeded(parser.parseOptionalKeyword("constant")))
1988 result.addAttribute(getConstantAttrName(result.name),
1989 parser.getBuilder().getUnitAttr());
1990
1991 StringAttr name;
1992 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
1993 result.attributes) ||
1994 parser.parseLParen())
1995 return failure();
1996
1997 Attribute value;
1998 if (parser.parseOptionalRParen()) {
1999 if (parser.parseAttribute(value, getValueAttrName(result.name),
2000 result.attributes) ||
2001 parser.parseRParen())
2002 return failure();
2003 }
2004
2005 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2006 SymbolRefAttr comdat;
2007 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2008 parser.parseRParen())
2009 return failure();
2010
2011 result.addAttribute(getComdatAttrName(result.name), comdat);
2012 }
2013
2014 SmallVector<Type, 1> types;
2015 if (parser.parseOptionalAttrDict(result.attributes) ||
2016 parser.parseOptionalColonTypeList(types))
2017 return failure();
2018
2019 if (types.size() > 1)
2020 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2021
2022 Region &initRegion = *result.addRegion();
2023 if (types.empty()) {
2024 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
2025 MLIRContext *context = parser.getContext();
2026 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
2027 strAttr.getValue().size());
2028 types.push_back(arrayType);
2029 } else {
2030 return parser.emitError(parser.getNameLoc(),
2031 "type can only be omitted for string globals");
2032 }
2033 } else {
2034 OptionalParseResult parseResult =
2035 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
2036 /*argTypes=*/{});
2037 if (parseResult.has_value() && failed(*parseResult))
2038 return failure();
2039 }
2040
2041 result.addAttribute(getGlobalTypeAttrName(result.name),
2042 TypeAttr::get(types[0]));
2043 return success();
2044}
2045
2046static bool isZeroAttribute(Attribute value) {
2047 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2048 return intValue.getValue().isZero();
2049 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2050 return fpValue.getValue().isZero();
2051 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(Val&: value))
2052 return isZeroAttribute(value: splatValue.getSplatValue<Attribute>());
2053 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2054 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2055 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2056 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2057 return false;
2058}
2059
2060LogicalResult GlobalOp::verify() {
2061 bool validType = isCompatibleOuterType(getType())
2062 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2063 LLVMMetadataType, LLVMLabelType>(getType())
2064 : llvm::isa<PointerElementTypeInterface>(getType());
2065 if (!validType)
2066 return emitOpError(
2067 "expects type to be a valid element type for an LLVM global");
2068 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2069 return emitOpError("must appear at the module level");
2070
2071 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2072 auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2073 IntegerType elementType =
2074 type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2075 if (!elementType || elementType.getWidth() != 8 ||
2076 type.getNumElements() != strAttr.getValue().size())
2077 return emitOpError(
2078 "requires an i8 array type of the length equal to that of the string "
2079 "attribute");
2080 }
2081
2082 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2083 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2084 return emitOpError()
2085 << "this target extension type cannot be used in a global";
2086
2087 if (Attribute value = getValueOrNull())
2088 return emitOpError() << "global with target extension type can only be "
2089 "initialized with zero-initializer";
2090 }
2091
2092 if (getLinkage() == Linkage::Common) {
2093 if (Attribute value = getValueOrNull()) {
2094 if (!isZeroAttribute(value)) {
2095 return emitOpError()
2096 << "expected zero value for '"
2097 << stringifyLinkage(Linkage::Common) << "' linkage";
2098 }
2099 }
2100 }
2101
2102 if (getLinkage() == Linkage::Appending) {
2103 if (!llvm::isa<LLVMArrayType>(getType())) {
2104 return emitOpError() << "expected array type for '"
2105 << stringifyLinkage(Linkage::Appending)
2106 << "' linkage";
2107 }
2108 }
2109
2110 if (failed(verifyComdat(*this, getComdat())))
2111 return failure();
2112
2113 std::optional<uint64_t> alignAttr = getAlignment();
2114 if (alignAttr.has_value()) {
2115 uint64_t value = alignAttr.value();
2116 if (!llvm::isPowerOf2_64(value))
2117 return emitError() << "alignment attribute is not a power of 2";
2118 }
2119
2120 return success();
2121}
2122
2123LogicalResult GlobalOp::verifyRegions() {
2124 if (Block *b = getInitializerBlock()) {
2125 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2126 if (ret.operand_type_begin() == ret.operand_type_end())
2127 return emitOpError("initializer region cannot return void");
2128 if (*ret.operand_type_begin() != getType())
2129 return emitOpError("initializer region type ")
2130 << *ret.operand_type_begin() << " does not match global type "
2131 << getType();
2132
2133 for (Operation &op : *b) {
2134 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2135 if (!iface || !iface.hasNoEffect())
2136 return op.emitError()
2137 << "ops with side effects not allowed in global initializers";
2138 }
2139
2140 if (getValueOrNull())
2141 return emitOpError("cannot have both initializer value and region");
2142 }
2143
2144 return success();
2145}
2146
2147//===----------------------------------------------------------------------===//
2148// LLVM::GlobalCtorsOp
2149//===----------------------------------------------------------------------===//
2150
2151LogicalResult
2152GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2153 for (Attribute ctor : getCtors()) {
2154 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2155 symbolTable)))
2156 return failure();
2157 }
2158 return success();
2159}
2160
2161LogicalResult GlobalCtorsOp::verify() {
2162 if (getCtors().size() != getPriorities().size())
2163 return emitError(
2164 "mismatch between the number of ctors and the number of priorities");
2165 return success();
2166}
2167
2168//===----------------------------------------------------------------------===//
2169// LLVM::GlobalDtorsOp
2170//===----------------------------------------------------------------------===//
2171
2172LogicalResult
2173GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2174 for (Attribute dtor : getDtors()) {
2175 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2176 symbolTable)))
2177 return failure();
2178 }
2179 return success();
2180}
2181
2182LogicalResult GlobalDtorsOp::verify() {
2183 if (getDtors().size() != getPriorities().size())
2184 return emitError(
2185 "mismatch between the number of dtors and the number of priorities");
2186 return success();
2187}
2188
2189//===----------------------------------------------------------------------===//
2190// ShuffleVectorOp
2191//===----------------------------------------------------------------------===//
2192
2193void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2194 Value v2, DenseI32ArrayAttr mask,
2195 ArrayRef<NamedAttribute> attrs) {
2196 auto containerType = v1.getType();
2197 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
2198 mask.size(),
2199 LLVM::isScalableVectorType(containerType));
2200 build(builder, state, vType, v1, v2, mask);
2201 state.addAttributes(attrs);
2202}
2203
2204void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2205 Value v2, ArrayRef<int32_t> mask) {
2206 build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2207}
2208
2209/// Build the result type of a shuffle vector operation.
2210static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2211 Type &resType, DenseI32ArrayAttr mask) {
2212 if (!LLVM::isCompatibleVectorType(type: v1Type))
2213 return parser.emitError(loc: parser.getCurrentLocation(),
2214 message: "expected an LLVM compatible vector type");
2215 resType = LLVM::getVectorType(LLVM::getVectorElementType(type: v1Type), mask.size(),
2216 LLVM::isScalableVectorType(vectorType: v1Type));
2217 return success();
2218}
2219
2220/// Nothing to do when the result type is inferred.
2221static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2222 Type resType, DenseI32ArrayAttr mask) {}
2223
2224LogicalResult ShuffleVectorOp::verify() {
2225 if (LLVM::isScalableVectorType(getV1().getType()) &&
2226 llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2227 return emitOpError("expected a splat operation for scalable vectors");
2228 return success();
2229}
2230
2231//===----------------------------------------------------------------------===//
2232// Implementations for LLVM::LLVMFuncOp.
2233//===----------------------------------------------------------------------===//
2234
2235// Add the entry block to the function.
2236Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
2237 assert(empty() && "function already has an entry block");
2238 OpBuilder::InsertionGuard g(builder);
2239 Block *entry = builder.createBlock(&getBody());
2240
2241 // FIXME: Allow passing in proper locations for the entry arguments.
2242 LLVMFunctionType type = getFunctionType();
2243 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2244 entry->addArgument(type.getParamType(i), getLoc());
2245 return entry;
2246}
2247
2248void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2249 StringRef name, Type type, LLVM::Linkage linkage,
2250 bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
2251 ArrayRef<NamedAttribute> attrs,
2252 ArrayRef<DictionaryAttr> argAttrs,
2253 std::optional<uint64_t> functionEntryCount) {
2254 result.addRegion();
2255 result.addAttribute(SymbolTable::getSymbolAttrName(),
2256 builder.getStringAttr(name));
2257 result.addAttribute(getFunctionTypeAttrName(result.name),
2258 TypeAttr::get(type));
2259 result.addAttribute(getLinkageAttrName(result.name),
2260 LinkageAttr::get(builder.getContext(), linkage));
2261 result.addAttribute(getCConvAttrName(result.name),
2262 CConvAttr::get(builder.getContext(), cconv));
2263 result.attributes.append(attrs.begin(), attrs.end());
2264 if (dsoLocal)
2265 result.addAttribute(getDsoLocalAttrName(result.name),
2266 builder.getUnitAttr());
2267 if (comdat)
2268 result.addAttribute(getComdatAttrName(result.name), comdat);
2269 if (functionEntryCount)
2270 result.addAttribute(getFunctionEntryCountAttrName(result.name),
2271 builder.getI64IntegerAttr(functionEntryCount.value()));
2272 if (argAttrs.empty())
2273 return;
2274
2275 assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
2276 "expected as many argument attribute lists as arguments");
2277 function_interface_impl::addArgAndResultAttrs(
2278 builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
2279 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2280}
2281
2282// Builds an LLVM function type from the given lists of input and output types.
2283// Returns a null type if any of the types provided are non-LLVM types, or if
2284// there is more than one output type.
2285static Type
2286buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2287 ArrayRef<Type> outputs,
2288 function_interface_impl::VariadicFlag variadicFlag) {
2289 Builder &b = parser.getBuilder();
2290 if (outputs.size() > 1) {
2291 parser.emitError(loc, message: "failed to construct function type: expected zero or "
2292 "one function result");
2293 return {};
2294 }
2295
2296 // Convert inputs to LLVM types, exit early on error.
2297 SmallVector<Type, 4> llvmInputs;
2298 for (auto t : inputs) {
2299 if (!isCompatibleType(type: t)) {
2300 parser.emitError(loc, message: "failed to construct function type: expected LLVM "
2301 "type for function arguments");
2302 return {};
2303 }
2304 llvmInputs.push_back(Elt: t);
2305 }
2306
2307 // No output is denoted as "void" in LLVM type system.
2308 Type llvmOutput =
2309 outputs.empty() ? LLVMVoidType::get(ctx: b.getContext()) : outputs.front();
2310 if (!isCompatibleType(type: llvmOutput)) {
2311 parser.emitError(loc, message: "failed to construct function type: expected LLVM "
2312 "type for function results")
2313 << llvmOutput;
2314 return {};
2315 }
2316 return LLVMFunctionType::get(llvmOutput, llvmInputs,
2317 variadicFlag.isVariadic());
2318}
2319
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               function-attributes?
//               function-body?
//
// Missing keywords default to external linkage, default visibility, no
// unnamed_addr and the C calling convention. These defaults are always
// materialized as attributes on `result`. `function-attributes` is the
// `attributes {...}` dictionary.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      getLinkageAttrName(result.name),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  // Written by parseFunctionSignature below.
  bool isVariadic;

  // Type-construction errors are reported at the start of the signature.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // The function type is assembled from the parsed argument types; a null
  // type means buildLLVMFunctionType already emitted a diagnostic.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Optional `vscale_range(min, max)`; both bounds are stored as i32 integer
  // attributes.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  function_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The body is optional. When it is absent the region stays empty, which the
  // printer treats as an external function.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
2412
2413// Print the LLVMFuncOp. Collects argument and result types and passes them to
2414// helper functions. Drops "void" result since it cannot be parsed back. Skips
2415// the external linkage since it is the default value.
2416void LLVMFuncOp::print(OpAsmPrinter &p) {
2417 p << ' ';
2418 if (getLinkage() != LLVM::Linkage::External)
2419 p << stringifyLinkage(getLinkage()) << ' ';
2420 StringRef visibility = stringifyVisibility(getVisibility_());
2421 if (!visibility.empty())
2422 p << visibility << ' ';
2423 if (auto unnamedAddr = getUnnamedAddr()) {
2424 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2425 if (!str.empty())
2426 p << str << ' ';
2427 }
2428 if (getCConv() != LLVM::CConv::C)
2429 p << stringifyCConv(getCConv()) << ' ';
2430
2431 p.printSymbolName(getName());
2432
2433 LLVMFunctionType fnType = getFunctionType();
2434 SmallVector<Type, 8> argTypes;
2435 SmallVector<Type, 1> resTypes;
2436 argTypes.reserve(fnType.getNumParams());
2437 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2438 argTypes.push_back(fnType.getParamType(i));
2439
2440 Type returnType = fnType.getReturnType();
2441 if (!llvm::isa<LLVMVoidType>(returnType))
2442 resTypes.push_back(returnType);
2443
2444 function_interface_impl::printFunctionSignature(p, *this, argTypes,
2445 isVarArg(), resTypes);
2446
2447 // Print vscale range if present
2448 if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
2449 p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
2450 << vscale->getMaxRange().getInt() << ')';
2451
2452 // Print the optional comdat selector.
2453 if (auto comdat = getComdat())
2454 p << " comdat(" << *comdat << ')';
2455
2456 function_interface_impl::printFunctionAttributes(
2457 p, *this,
2458 {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2459 getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
2460 getComdatAttrName(), getUnnamedAddrAttrName(),
2461 getVscaleRangeAttrName()});
2462
2463 // Print the body if this is not an external function.
2464 Region &body = getBody();
2465 if (!body.empty()) {
2466 p << ' ';
2467 p.printRegion(body, /*printEntryBlockArgs=*/false,
2468 /*printBlockTerminators=*/true);
2469 }
2470}
2471
/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
/// - functions don't have 'common' linkage;
/// - the optional comdat selector passes `verifyComdat`;
/// - external functions have 'external' or 'extern_weak' linkage;
/// - in functions with a body, every 'llvm.landingpad' result and every
///   'llvm.resume' operand has one and the same type.
/// NOTE(review): an earlier version of this comment also stated that vararg is
/// only supported for external functions; no such check is performed here.
LogicalResult LLVMFuncOp::verify() {
  // 'common' linkage is rejected for every function.
  if (getLinkage() == LLVM::Linkage::Common)
    return emitOpError() << "functions cannot have '"
                         << stringifyLinkage(LLVM::Linkage::Common)
                         << "' linkage";

  if (failed(verifyComdat(*this, getComdat())))
    return failure();

  // Declarations (no body) only need the linkage check below.
  if (isExternal()) {
    if (getLinkage() != LLVM::Linkage::External &&
        getLinkage() != LLVM::Linkage::ExternWeak)
      return emitOpError() << "external functions must have '"
                           << stringifyLinkage(LLVM::Linkage::External)
                           << "' or '"
                           << stringifyLinkage(LLVM::Linkage::ExternWeak)
                           << "' linkage";
    return success();
  }

  // Walk the body. The first landingpad result type or resume operand type
  // seen becomes the reference type. Any later landingpad/resume whose type
  // differs records a diagnostic and interrupts the walk.
  Type landingpadResultTy;
  StringRef diagnosticMessage;
  bool isLandingpadTypeConsistent =
      !walk([&](Operation *op) {
        // Compares `type` against the reference type, establishing the
        // reference on first use.
        const auto checkType = [&](Type type, StringRef errorMessage) {
          if (!landingpadResultTy) {
            landingpadResultTy = type;
            return WalkResult::advance();
          }
          if (landingpadResultTy != type) {
            diagnosticMessage = errorMessage;
            return WalkResult::interrupt();
          }
          return WalkResult::advance();
        };
        return TypeSwitch<Operation *, WalkResult>(op)
            .Case<LandingpadOp>([&](auto landingpad) {
              constexpr StringLiteral errorMessage =
                  "'llvm.landingpad' should have a consistent result type "
                  "inside a function";
              return checkType(landingpad.getType(), errorMessage);
            })
            .Case<ResumeOp>([&](auto resume) {
              constexpr StringLiteral errorMessage =
                  "'llvm.resume' should have a consistent input type inside a "
                  "function";
              return checkType(resume.getValue().getType(), errorMessage);
            })
            // All other ops are irrelevant to this check.
            .Default([](auto) { return WalkResult::skip(); });
      }).wasInterrupted();
  if (!isLandingpadTypeConsistent) {
    assert(!diagnosticMessage.empty() &&
           "Expecting a non-empty diagnostic message");
    return emitError(diagnosticMessage);
  }

  return success();
}
2534
2535/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2536/// - entry block arguments are of LLVM types.
2537LogicalResult LLVMFuncOp::verifyRegions() {
2538 if (isExternal())
2539 return success();
2540
2541 unsigned numArguments = getFunctionType().getNumParams();
2542 Block &entryBlock = front();
2543 for (unsigned i = 0; i < numArguments; ++i) {
2544 Type argType = entryBlock.getArgument(i).getType();
2545 if (!isCompatibleType(argType))
2546 return emitOpError("entry block argument #")
2547 << i << " is not of LLVM type";
2548 }
2549
2550 return success();
2551}
2552
2553Region *LLVMFuncOp::getCallableRegion() {
2554 if (isExternal())
2555 return nullptr;
2556 return &getBody();
2557}
2558
2559//===----------------------------------------------------------------------===//
2560// ZeroOp.
2561//===----------------------------------------------------------------------===//
2562
2563LogicalResult LLVM::ZeroOp::verify() {
2564 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
2565 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
2566 return emitOpError()
2567 << "target extension type does not support zero-initializer";
2568
2569 return success();
2570}
2571
2572//===----------------------------------------------------------------------===//
2573// ConstantOp.
2574//===----------------------------------------------------------------------===//
2575
2576LogicalResult LLVM::ConstantOp::verify() {
2577 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
2578 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
2579 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
2580 !arrayType.getElementType().isInteger(8)) {
2581 return emitOpError() << "expected array type of "
2582 << sAttr.getValue().size()
2583 << " i8 elements for the string constant";
2584 }
2585 return success();
2586 }
2587 if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
2588 if (structType.getBody().size() != 2 ||
2589 structType.getBody()[0] != structType.getBody()[1]) {
2590 return emitError() << "expected struct type with two elements of the "
2591 "same type, the type of a complex constant";
2592 }
2593
2594 auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
2595 if (!arrayAttr || arrayAttr.size() != 2) {
2596 return emitOpError() << "expected array attribute with two elements, "
2597 "representing a complex constant";
2598 }
2599 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
2600 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
2601 if (!re || !im || re.getType() != im.getType()) {
2602 return emitOpError()
2603 << "expected array attribute with two elements of the same type";
2604 }
2605
2606 Type elementType = structType.getBody()[0];
2607 if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
2608 elementType)) {
2609 return emitError()
2610 << "expected struct element types to be floating point type or "
2611 "integer type";
2612 }
2613 return success();
2614 }
2615 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2616 return emitOpError() << "does not support target extension type.";
2617 }
2618 if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
2619 return emitOpError()
2620 << "only supports integer, float, string or elements attributes";
2621 if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
2622 if (!llvm::isa<IntegerType>(getType()))
2623 return emitOpError() << "expected integer type";
2624 }
2625 if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
2626 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
2627 unsigned floatWidth = APFloat::getSizeInBits(sem);
2628 if (auto floatTy = dyn_cast<FloatType>(getType())) {
2629 if (floatTy.getWidth() != floatWidth) {
2630 return emitOpError() << "expected float type of width " << floatWidth;
2631 }
2632 }
2633 // See the comment for getLLVMConstant for more details about why 8-bit
2634 // floats can be represented by integers.
2635 if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
2636 return emitOpError() << "expected integer type of width " << floatWidth;
2637 }
2638 }
2639 if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
2640 if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
2641 !isa<LLVM::LLVMFixedVectorType>(getType()) &&
2642 !isa<LLVM::LLVMScalableVectorType>(getType()))
2643 return emitOpError() << "expected vector or array type";
2644 }
2645 return success();
2646}
2647
2648bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
2649 // The value's type must be the same as the provided type.
2650 auto typedAttr = dyn_cast<TypedAttr>(value);
2651 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
2652 return false;
2653 // The value's type must be an LLVM compatible type.
2654 if (!isCompatibleType(type))
2655 return false;
2656 // TODO: Add support for additional attributes kinds once needed.
2657 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
2658}
2659
2660ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
2661 Type type, Location loc) {
2662 if (isBuildableWith(value, type))
2663 return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
2664 return nullptr;
2665}
2666
2667// Constant op constant-folds to its value.
2668OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
2669
2670//===----------------------------------------------------------------------===//
2671// AtomicRMWOp
2672//===----------------------------------------------------------------------===//
2673
2674void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
2675 AtomicBinOp binOp, Value ptr, Value val,
2676 AtomicOrdering ordering, StringRef syncscope,
2677 unsigned alignment, bool isVolatile) {
2678 build(builder, state, val.getType(), binOp, ptr, val, ordering,
2679 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
2680 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
2681 /*access_groups=*/nullptr,
2682 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
2683}
2684
2685LogicalResult AtomicRMWOp::verify() {
2686 auto valType = getVal().getType();
2687 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
2688 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
2689 if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2690 return emitOpError("expected LLVM IR floating point type");
2691 } else if (getBinOp() == AtomicBinOp::xchg) {
2692 if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/true))
2693 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2694 } else {
2695 auto intType = llvm::dyn_cast<IntegerType>(valType);
2696 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2697 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2698 intBitWidth != 64)
2699 return emitOpError("expected LLVM IR integer type");
2700 }
2701
2702 if (static_cast<unsigned>(getOrdering()) <
2703 static_cast<unsigned>(AtomicOrdering::monotonic))
2704 return emitOpError() << "expected at least '"
2705 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2706 << "' ordering";
2707
2708 return success();
2709}
2710
2711//===----------------------------------------------------------------------===//
2712// AtomicCmpXchgOp
2713//===----------------------------------------------------------------------===//
2714
2715/// Returns an LLVM struct type that contains a value type and a boolean type.
2716static LLVMStructType getValAndBoolStructType(Type valType) {
2717 auto boolType = IntegerType::get(valType.getContext(), 1);
2718 return LLVMStructType::getLiteral(context: valType.getContext(), types: {valType, boolType});
2719}
2720
2721void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
2722 Value ptr, Value cmp, Value val,
2723 AtomicOrdering successOrdering,
2724 AtomicOrdering failureOrdering, StringRef syncscope,
2725 unsigned alignment, bool isWeak, bool isVolatile) {
2726 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
2727 successOrdering, failureOrdering,
2728 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
2729 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
2730 isVolatile, /*access_groups=*/nullptr,
2731 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
2732}
2733
2734LogicalResult AtomicCmpXchgOp::verify() {
2735 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
2736 if (!ptrType)
2737 return emitOpError("expected LLVM IR pointer type for operand #0");
2738 auto valType = getVal().getType();
2739 if (!isTypeCompatibleWithAtomicOp(valType,
2740 /*isPointerTypeAllowed=*/true))
2741 return emitOpError("unexpected LLVM IR type");
2742 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
2743 getFailureOrdering() < AtomicOrdering::monotonic)
2744 return emitOpError("ordering must be at least 'monotonic'");
2745 if (getFailureOrdering() == AtomicOrdering::release ||
2746 getFailureOrdering() == AtomicOrdering::acq_rel)
2747 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2748 return success();
2749}
2750
2751//===----------------------------------------------------------------------===//
2752// FenceOp
2753//===----------------------------------------------------------------------===//
2754
2755void FenceOp::build(OpBuilder &builder, OperationState &state,
2756 AtomicOrdering ordering, StringRef syncscope) {
2757 build(builder, state, ordering,
2758 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
2759}
2760
2761LogicalResult FenceOp::verify() {
2762 if (getOrdering() == AtomicOrdering::not_atomic ||
2763 getOrdering() == AtomicOrdering::unordered ||
2764 getOrdering() == AtomicOrdering::monotonic)
2765 return emitOpError("can be given only acquire, release, acq_rel, "
2766 "and seq_cst orderings");
2767 return success();
2768}
2769
2770//===----------------------------------------------------------------------===//
2771// Verifier for extension ops
2772//===----------------------------------------------------------------------===//
2773
2774/// Verifies that the given extension operation operates on consistent scalars
2775/// or vectors, and that the target width is larger than the input width.
2776template <class ExtOp>
2777static LogicalResult verifyExtOp(ExtOp op) {
2778 IntegerType inputType, outputType;
2779 if (isCompatibleVectorType(op.getArg().getType())) {
2780 if (!isCompatibleVectorType(op.getResult().getType()))
2781 return op.emitError(
2782 "input type is a vector but output type is an integer");
2783 if (getVectorNumElements(op.getArg().getType()) !=
2784 getVectorNumElements(op.getResult().getType()))
2785 return op.emitError("input and output vectors are of incompatible shape");
2786 // Because this is a CastOp, the element of vectors is guaranteed to be an
2787 // integer.
2788 inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
2789 outputType =
2790 cast<IntegerType>(getVectorElementType(op.getResult().getType()));
2791 } else {
2792 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
2793 // an integer.
2794 inputType = cast<IntegerType>(op.getArg().getType());
2795 outputType = dyn_cast<IntegerType>(op.getResult().getType());
2796 if (!outputType)
2797 return op.emitError(
2798 "input type is an integer but output type is a vector");
2799 }
2800
2801 if (outputType.getWidth() <= inputType.getWidth())
2802 return op.emitError("integer width of the output type is smaller or "
2803 "equal to the integer width of the input type");
2804 return success();
2805}
2806
2807//===----------------------------------------------------------------------===//
2808// ZExtOp
2809//===----------------------------------------------------------------------===//
2810
2811LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
2812
2813OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
2814 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
2815 if (!arg)
2816 return {};
2817
2818 size_t targetSize = cast<IntegerType>(getType()).getWidth();
2819 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
2820}
2821
2822//===----------------------------------------------------------------------===//
2823// SExtOp
2824//===----------------------------------------------------------------------===//
2825
2826LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
2827
2828//===----------------------------------------------------------------------===//
2829// Folder and verifier for LLVM::BitcastOp
2830//===----------------------------------------------------------------------===//
2831
2832/// Folds a cast op that can be chained.
2833template <typename T>
2834static OpFoldResult foldChainableCast(T castOp,
2835 typename T::FoldAdaptor adaptor) {
2836 // cast(x : T0, T0) -> x
2837 if (castOp.getArg().getType() == castOp.getType())
2838 return castOp.getArg();
2839 if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
2840 // cast(cast(x : T0, T1), T0) -> x
2841 if (prev.getArg().getType() == castOp.getType())
2842 return prev.getArg();
2843 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
2844 castOp.getArgMutable().set(prev.getArg());
2845 return Value{castOp};
2846 }
2847 return {};
2848}
2849
2850OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
2851 return foldChainableCast(*this, adaptor);
2852}
2853
2854LogicalResult LLVM::BitcastOp::verify() {
2855 auto resultType = llvm::dyn_cast<LLVMPointerType>(
2856 extractVectorElementType(getResult().getType()));
2857 auto sourceType = llvm::dyn_cast<LLVMPointerType>(
2858 extractVectorElementType(getArg().getType()));
2859
2860 // If one of the types is a pointer (or vector of pointers), then
2861 // both source and result type have to be pointers.
2862 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
2863 return emitOpError("can only cast pointers from and to pointers");
2864
2865 if (!resultType)
2866 return success();
2867
2868 auto isVector =
2869 llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
2870
2871 // Due to bitcast requiring both operands to be of the same size, it is not
2872 // possible for only one of the two to be a pointer of vectors.
2873 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
2874 return emitOpError("cannot cast pointer to vector of pointers");
2875
2876 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
2877 return emitOpError("cannot cast vector of pointers to pointer");
2878
2879 // Bitcast cannot cast between pointers of different address spaces.
2880 // 'llvm.addrspacecast' must be used for this purpose instead.
2881 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
2882 return emitOpError("cannot cast pointers of different address spaces, "
2883 "use 'llvm.addrspacecast' instead");
2884
2885 return success();
2886}
2887
2888//===----------------------------------------------------------------------===//
2889// Folder for LLVM::AddrSpaceCastOp
2890//===----------------------------------------------------------------------===//
2891
2892OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
2893 return foldChainableCast(*this, adaptor);
2894}
2895
2896//===----------------------------------------------------------------------===//
2897// Folder for LLVM::GEPOp
2898//===----------------------------------------------------------------------===//
2899
/// Folds `GEPOp`:
/// - a GEP whose result type equals its base type and whose single index is
///   the constant zero folds to the base pointer;
/// - dynamic indices whose folded operand value is an integer constant that
///   fits in `kGEPConstantBitWidth` signed bits are turned into raw constant
///   indices, updating the op in place.
/// Returns the base value, the op itself (after an in-place update), or null
/// when nothing changed.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  // View over all indices: raw constant indices interleaved with the folded
  // attributes of the dynamic index operands (presumably null for operands
  // that are not constant, given the dyn_cast_or_null uses below).
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  // `gepArgs` is rebuilt for every index; `changed` records whether at least
  // one dynamic index was promoted.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {

      // Keep the index as it currently is on the op (SSA value or existing
      // raw constant).
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());

      continue;
    }

    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }
  if (changed) {
    // Split the rebuilt argument list back into the op's raw constant indices
    // and dynamic index operands, and update the op in place.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }

  return {};
}
2945
2946//===----------------------------------------------------------------------===//
2947// ShlOp
2948//===----------------------------------------------------------------------===//
2949
2950OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
2951 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
2952 if (!rhs)
2953 return {};
2954
2955 if (rhs.getValue().getZExtValue() >=
2956 getLhs().getType().getIntOrFloatBitWidth())
2957 return {}; // TODO: Fold into poison.
2958
2959 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
2960 if (!lhs)
2961 return {};
2962
2963 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
2964}
2965
2966//===----------------------------------------------------------------------===//
2967// OrOp
2968//===----------------------------------------------------------------------===//
2969
2970OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
2971 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
2972 if (!lhs)
2973 return {};
2974
2975 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
2976 if (!rhs)
2977 return {};
2978
2979 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
2980}
2981
2982//===----------------------------------------------------------------------===//
2983// CallIntrinsicOp
2984//===----------------------------------------------------------------------===//
2985
2986LogicalResult CallIntrinsicOp::verify() {
2987 if (!getIntrin().starts_with("llvm."))
2988 return emitOpError() << "intrinsic name must start with 'llvm.'";
2989 return success();
2990}
2991
2992//===----------------------------------------------------------------------===//
2993// OpAsmDialectInterface
2994//===----------------------------------------------------------------------===//
2995
namespace {
/// Assembly-printing hooks for the LLVM dialect. The attribute kinds listed in
/// `getAlias` (access groups, alias scopes, debug info, loop annotations and
/// TBAA) are printed through aliases named after their mnemonic. All other
/// attributes get no alias.
struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  /// For the listed attribute kinds, writes the attribute's mnemonic to `os`
  /// and returns `OverridableAlias`. Returns `NoAlias` for anything else.
  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    return TypeSwitch<Attribute, AliasResult>(attr)
        .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr,
              DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr,
              DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
              DIGlobalVariableExpressionAttr, DILabelAttr, DILexicalBlockAttr,
              DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr,
              DINamespaceAttr, DINullTypeAttr, DISubprogramAttr,
              DISubroutineTypeAttr, LoopAnnotationAttr, LoopVectorizeAttr,
              LoopInterleaveAttr, LoopUnrollAttr, LoopUnrollAndJamAttr,
              LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr,
              LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, TBAATagAttr,
              TBAATypeDescriptorAttr>([&](auto attr) {
          // The alias base name is the attribute's mnemonic.
          os << decltype(attr)::getMnemonic();
          return AliasResult::OverridableAlias;
        })
        .Default([](Attribute) { return AliasResult::NoAlias; });
  }
};
} // namespace
3020
3021//===----------------------------------------------------------------------===//
3022// LinkerOptionsOp
3023//===----------------------------------------------------------------------===//
3024
3025LogicalResult LinkerOptionsOp::verify() {
3026 if (mlir::Operation *parentOp = (*this)->getParentOp();
3027 parentOp && !satisfiesLLVMModule(parentOp))
3028 return emitOpError("must appear at the module level");
3029 return success();
3030}
3031
3032//===----------------------------------------------------------------------===//
3033// LLVMDialect initialization, type parsing, and registration.
3034//===----------------------------------------------------------------------===//
3035
/// Registers the LLVM dialect's attributes, types, operations (from both the
/// generated LLVMOps and LLVMIntrinsicOps lists) and interfaces.
void LLVMDialect::initialize() {
  registerAttributes();

  // Types registered explicitly here; registerTypes() below adds the rest
  // (presumably the TableGen-defined ones — confirm against the .td files).
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMStructType>();
  // clang-format on
  registerTypes();

  // The op list is spliced in from the generated .inc files; the preprocessor
  // directives must stay on their own lines.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      ,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
  // clang-format off
  addInterfaces<LLVMOpAsmDialectInterface>();
  // clang-format on
  detail::addLLVMInlinerInterface(this);
}
3065
3066#define GET_OP_CLASSES
3067#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
3068
3069#define GET_OP_CLASSES
3070#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
3071
3072LogicalResult LLVMDialect::verifyDataLayoutString(
3073 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
3074 llvm::Expected<llvm::DataLayout> maybeDataLayout =
3075 llvm::DataLayout::parse(descr);
3076 if (maybeDataLayout)
3077 return success();
3078
3079 std::string message;
3080 llvm::raw_string_ostream messageStream(message);
3081 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
3082 reportError("invalid data layout descriptor: " + messageStream.str());
3083 return failure();
3084}
3085
3086/// Verify LLVM dialect attributes.
3087LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
3088 NamedAttribute attr) {
3089 // If the data layout attribute is present, it must use the LLVM data layout
3090 // syntax. Try parsing it and report errors in case of failure. Users of this
3091 // attribute may assume it is well-formed and can pass it to the (asserting)
3092 // llvm::DataLayout constructor.
3093 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
3094 return success();
3095 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
3096 return verifyDataLayoutString(
3097 stringAttr.getValue(),
3098 [op](const Twine &message) { op->emitOpError() << message.str(); });
3099
3100 return op->emitOpError() << "expected '"
3101 << LLVM::LLVMDialect::getDataLayoutAttrName()
3102 << "' to be a string attributes";
3103}
3104
3105LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
3106 Type paramType,
3107 NamedAttribute paramAttr) {
3108 // LLVM attribute may be attached to a result of operation that has not been
3109 // converted to LLVM dialect yet, so the result may have a type with unknown
3110 // representation in LLVM dialect type space. In this case we cannot verify
3111 // whether the attribute may be
3112 bool verifyValueType = isCompatibleType(paramType);
3113 StringAttr name = paramAttr.getName();
3114
3115 auto checkUnitAttrType = [&]() -> LogicalResult {
3116 if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
3117 return op->emitError() << name << " should be a unit attribute";
3118 return success();
3119 };
3120 auto checkTypeAttrType = [&]() -> LogicalResult {
3121 if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
3122 return op->emitError() << name << " should be a type attribute";
3123 return success();
3124 };
3125 auto checkIntegerAttrType = [&]() -> LogicalResult {
3126 if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
3127 return op->emitError() << name << " should be an integer attribute";
3128 return success();
3129 };
3130 auto checkPointerType = [&]() -> LogicalResult {
3131 if (!llvm::isa<LLVMPointerType>(paramType))
3132 return op->emitError()
3133 << name << " attribute attached to non-pointer LLVM type";
3134 return success();
3135 };
3136 auto checkIntegerType = [&]() -> LogicalResult {
3137 if (!llvm::isa<IntegerType>(paramType))
3138 return op->emitError()
3139 << name << " attribute attached to non-integer LLVM type";
3140 return success();
3141 };
3142 auto checkPointerTypeMatches = [&]() -> LogicalResult {
3143 if (failed(checkPointerType()))
3144 return failure();
3145
3146 return success();
3147 };
3148
3149 // Check a unit attribute that is attached to a pointer value.
3150 if (name == LLVMDialect::getNoAliasAttrName() ||
3151 name == LLVMDialect::getReadonlyAttrName() ||
3152 name == LLVMDialect::getReadnoneAttrName() ||
3153 name == LLVMDialect::getWriteOnlyAttrName() ||
3154 name == LLVMDialect::getNestAttrName() ||
3155 name == LLVMDialect::getNoCaptureAttrName() ||
3156 name == LLVMDialect::getNoFreeAttrName() ||
3157 name == LLVMDialect::getNonNullAttrName()) {
3158 if (failed(checkUnitAttrType()))
3159 return failure();
3160 if (verifyValueType && failed(checkPointerType()))
3161 return failure();
3162 return success();
3163 }
3164
3165 // Check a type attribute that is attached to a pointer value.
3166 if (name == LLVMDialect::getStructRetAttrName() ||
3167 name == LLVMDialect::getByValAttrName() ||
3168 name == LLVMDialect::getByRefAttrName() ||
3169 name == LLVMDialect::getInAllocaAttrName() ||
3170 name == LLVMDialect::getPreallocatedAttrName()) {
3171 if (failed(checkTypeAttrType()))
3172 return failure();
3173 if (verifyValueType && failed(checkPointerTypeMatches()))
3174 return failure();
3175 return success();
3176 }
3177
3178 // Check a unit attribute that is attached to an integer value.
3179 if (name == LLVMDialect::getSExtAttrName() ||
3180 name == LLVMDialect::getZExtAttrName()) {
3181 if (failed(checkUnitAttrType()))
3182 return failure();
3183 if (verifyValueType && failed(checkIntegerType()))
3184 return failure();
3185 return success();
3186 }
3187
3188 // Check an integer attribute that is attached to a pointer value.
3189 if (name == LLVMDialect::getAlignAttrName() ||
3190 name == LLVMDialect::getDereferenceableAttrName() ||
3191 name == LLVMDialect::getDereferenceableOrNullAttrName() ||
3192 name == LLVMDialect::getStackAlignmentAttrName()) {
3193 if (failed(checkIntegerAttrType()))
3194 return failure();
3195 if (verifyValueType && failed(checkPointerType()))
3196 return failure();
3197 return success();
3198 }
3199
3200 // Check a unit attribute that can be attached to arbitrary types.
3201 if (name == LLVMDialect::getNoUndefAttrName() ||
3202 name == LLVMDialect::getInRegAttrName() ||
3203 name == LLVMDialect::getReturnedAttrName())
3204 return checkUnitAttrType();
3205
3206 return success();
3207}
3208
3209/// Verify LLVMIR function argument attributes.
3210LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
3211 unsigned regionIdx,
3212 unsigned argIdx,
3213 NamedAttribute argAttr) {
3214 auto funcOp = dyn_cast<FunctionOpInterface>(op);
3215 if (!funcOp)
3216 return success();
3217 Type argType = funcOp.getArgumentTypes()[argIdx];
3218
3219 return verifyParameterAttribute(op, argType, argAttr);
3220}
3221
3222LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
3223 unsigned regionIdx,
3224 unsigned resIdx,
3225 NamedAttribute resAttr) {
3226 auto funcOp = dyn_cast<FunctionOpInterface>(op);
3227 if (!funcOp)
3228 return success();
3229 Type resType = funcOp.getResultTypes()[resIdx];
3230
3231 // Check to see if this function has a void return with a result attribute
3232 // to it. It isn't clear what semantics we would assign to that.
3233 if (llvm::isa<LLVMVoidType>(resType))
3234 return op->emitError() << "cannot attach result attributes to functions "
3235 "with a void return";
3236
3237 // Check to see if this attribute is allowed as a result attribute. Only
3238 // explicitly forbidden LLVM attributes will cause an error.
3239 auto name = resAttr.getName();
3240 if (name == LLVMDialect::getAllocAlignAttrName() ||
3241 name == LLVMDialect::getAllocatedPointerAttrName() ||
3242 name == LLVMDialect::getByValAttrName() ||
3243 name == LLVMDialect::getByRefAttrName() ||
3244 name == LLVMDialect::getInAllocaAttrName() ||
3245 name == LLVMDialect::getNestAttrName() ||
3246 name == LLVMDialect::getNoCaptureAttrName() ||
3247 name == LLVMDialect::getNoFreeAttrName() ||
3248 name == LLVMDialect::getPreallocatedAttrName() ||
3249 name == LLVMDialect::getReadnoneAttrName() ||
3250 name == LLVMDialect::getReadonlyAttrName() ||
3251 name == LLVMDialect::getReturnedAttrName() ||
3252 name == LLVMDialect::getStackAlignmentAttrName() ||
3253 name == LLVMDialect::getStructRetAttrName() ||
3254 name == LLVMDialect::getWriteOnlyAttrName())
3255 return op->emitError() << name << " is not a valid result attribute";
3256 return verifyParameterAttribute(op, resType, resAttr);
3257}
3258
3259Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
3260 Type type, Location loc) {
3261 return LLVM::ConstantOp::materialize(builder, value, type, loc);
3262}
3263
3264//===----------------------------------------------------------------------===//
3265// Utility functions.
3266//===----------------------------------------------------------------------===//
3267
3268Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
3269 StringRef name, StringRef value,
3270 LLVM::Linkage linkage) {
3271 assert(builder.getInsertionBlock() &&
3272 builder.getInsertionBlock()->getParentOp() &&
3273 "expected builder to point to a block constrained in an op");
3274 auto module =
3275 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
3276 assert(module && "builder points to an op outside of a module");
3277
3278 // Create the global at the entry of the module.
3279 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
3280 MLIRContext *ctx = builder.getContext();
3281 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
3282 auto global = moduleBuilder.create<LLVM::GlobalOp>(
3283 loc, type, /*isConstant=*/true, linkage, name,
3284 builder.getStringAttr(value), /*alignment=*/0);
3285
3286 LLVMPointerType ptrType = LLVMPointerType::get(ctx);
3287 // Get the pointer to the first character in the global string.
3288 Value globalPtr =
3289 builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr());
3290 return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr,
3291 ArrayRef<GEPArg>{0, 0});
3292}
3293
3294bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
3295 return op->hasTrait<OpTrait::SymbolTable>() &&
3296 op->hasTrait<OpTrait::IsIsolatedFromAbove>();
3297}
3298

// source code of mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp