1 | //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the types and operation details for the LLVM IR dialect in |
10 | // MLIR, and the LLVM IR dialect. It also registers the dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "LLVMInlining.h" |
16 | #include "TypeDetail.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/BuiltinOps.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/DialectImplementation.h" |
24 | #include "mlir/IR/MLIRContext.h" |
25 | #include "mlir/IR/Matchers.h" |
26 | #include "mlir/Interfaces/FunctionImplementation.h" |
27 | |
28 | #include "llvm/ADT/SCCIterator.h" |
29 | #include "llvm/ADT/TypeSwitch.h" |
30 | #include "llvm/AsmParser/Parser.h" |
31 | #include "llvm/Bitcode/BitcodeReader.h" |
32 | #include "llvm/Bitcode/BitcodeWriter.h" |
33 | #include "llvm/IR/Attributes.h" |
34 | #include "llvm/IR/Function.h" |
35 | #include "llvm/IR/Type.h" |
36 | #include "llvm/Support/Error.h" |
37 | #include "llvm/Support/Mutex.h" |
38 | #include "llvm/Support/SourceMgr.h" |
39 | |
40 | #include <numeric> |
41 | #include <optional> |
42 | |
43 | using namespace mlir; |
44 | using namespace mlir::LLVM; |
45 | using mlir::LLVM::cconv::getMaxEnumValForCConv; |
46 | using mlir::LLVM::linkage::getMaxEnumValForLinkage; |
47 | |
48 | #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" |
49 | |
50 | //===----------------------------------------------------------------------===// |
51 | // Property Helpers |
52 | //===----------------------------------------------------------------------===// |
53 | |
54 | //===----------------------------------------------------------------------===// |
55 | // IntegerOverflowFlags |
56 | |
57 | namespace mlir { |
58 | static Attribute convertToAttribute(MLIRContext *ctx, |
59 | IntegerOverflowFlags flags) { |
60 | return IntegerOverflowFlagsAttr::get(ctx, flags); |
61 | } |
62 | |
63 | static LogicalResult |
64 | convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr, |
65 | function_ref<InFlightDiagnostic()> emitError) { |
66 | auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr); |
67 | if (!flagsAttr) { |
68 | return emitError() << "expected 'overflowFlags' attribute to be an " |
69 | "IntegerOverflowFlagsAttr, but got " |
70 | << attr; |
71 | } |
72 | flags = flagsAttr.getValue(); |
73 | return success(); |
74 | } |
75 | } // namespace mlir |
76 | |
77 | static ParseResult parseOverflowFlags(AsmParser &p, |
78 | IntegerOverflowFlags &flags) { |
79 | if (failed(result: p.parseOptionalKeyword(keyword: "overflow" ))) { |
80 | flags = IntegerOverflowFlags::none; |
81 | return success(); |
82 | } |
83 | if (p.parseLess()) |
84 | return failure(); |
85 | do { |
86 | StringRef kw; |
87 | SMLoc loc = p.getCurrentLocation(); |
88 | if (p.parseKeyword(keyword: &kw)) |
89 | return failure(); |
90 | std::optional<IntegerOverflowFlags> flag = |
91 | symbolizeIntegerOverflowFlags(kw); |
92 | if (!flag) |
93 | return p.emitError(loc, |
94 | message: "invalid overflow flag: expected nsw, nuw, or none" ); |
95 | flags = flags | *flag; |
96 | } while (succeeded(result: p.parseOptionalComma())); |
97 | return p.parseGreater(); |
98 | } |
99 | |
100 | static void printOverflowFlags(AsmPrinter &p, Operation *op, |
101 | IntegerOverflowFlags flags) { |
102 | if (flags == IntegerOverflowFlags::none) |
103 | return; |
104 | p << " overflow<" ; |
105 | SmallVector<StringRef, 2> strs; |
106 | if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) |
107 | strs.push_back(Elt: "nsw" ); |
108 | if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) |
109 | strs.push_back(Elt: "nuw" ); |
110 | llvm::interleaveComma(c: strs, os&: p); |
111 | p << ">" ; |
112 | } |
113 | |
114 | //===----------------------------------------------------------------------===// |
115 | // Attribute Helpers |
116 | //===----------------------------------------------------------------------===// |
117 | |
118 | static constexpr const char kElemTypeAttrName[] = "elem_type" ; |
119 | |
120 | static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { |
121 | SmallVector<NamedAttribute, 8> filteredAttrs( |
122 | llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) { |
123 | if (attr.getName() == "fastmathFlags" ) { |
124 | auto defAttr = |
125 | FastmathFlagsAttr::get(attr.getValue().getContext(), {}); |
126 | return defAttr != attr.getValue(); |
127 | } |
128 | return true; |
129 | })); |
130 | return filteredAttrs; |
131 | } |
132 | |
133 | static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, |
134 | NamedAttrList &result) { |
135 | return parser.parseOptionalAttrDict(result); |
136 | } |
137 | |
138 | static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, |
139 | DictionaryAttr attrs) { |
140 | auto filteredAttrs = processFMFAttr(attrs.getValue()); |
141 | if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) { |
142 | printer.printOptionalAttrDict( |
143 | attrs: filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()}); |
144 | } else { |
145 | printer.printOptionalAttrDict(attrs: filteredAttrs); |
146 | } |
147 | } |
148 | |
149 | /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and |
150 | /// fully defined llvm.func. |
151 | static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, |
152 | Operation *op, |
153 | SymbolTableCollection &symbolTable) { |
154 | StringRef name = symbol.getValue(); |
155 | auto func = |
156 | symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); |
157 | if (!func) |
158 | return op->emitOpError(message: "'" ) |
159 | << name << "' does not reference a valid LLVM function" ; |
160 | if (func.isExternal()) |
161 | return op->emitOpError(message: "'" ) << name << "' does not have a definition" ; |
162 | return success(); |
163 | } |
164 | |
165 | /// Returns a boolean type that has the same shape as `type`. It supports both |
166 | /// fixed size vectors as well as scalable vectors. |
167 | static Type getI1SameShape(Type type) { |
168 | Type i1Type = IntegerType::get(type.getContext(), 1); |
169 | if (LLVM::isCompatibleVectorType(type)) |
170 | return LLVM::getVectorType(elementType: i1Type, numElements: LLVM::getVectorNumElements(type)); |
171 | return i1Type; |
172 | } |
173 | |
174 | // Parses one of the keywords provided in the list `keywords` and returns the |
175 | // position of the parsed keyword in the list. If none of the keywords from the |
176 | // list is parsed, returns -1. |
177 | static int parseOptionalKeywordAlternative(OpAsmParser &parser, |
178 | ArrayRef<StringRef> keywords) { |
179 | for (const auto &en : llvm::enumerate(First&: keywords)) { |
180 | if (succeeded(result: parser.parseOptionalKeyword(keyword: en.value()))) |
181 | return en.index(); |
182 | } |
183 | return -1; |
184 | } |
185 | |
186 | namespace { |
187 | template <typename Ty> |
188 | struct EnumTraits {}; |
189 | |
190 | #define REGISTER_ENUM_TYPE(Ty) \ |
191 | template <> \ |
192 | struct EnumTraits<Ty> { \ |
193 | static StringRef stringify(Ty value) { return stringify##Ty(value); } \ |
194 | static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ |
195 | } |
196 | |
197 | REGISTER_ENUM_TYPE(Linkage); |
198 | REGISTER_ENUM_TYPE(UnnamedAddr); |
199 | REGISTER_ENUM_TYPE(CConv); |
200 | REGISTER_ENUM_TYPE(Visibility); |
201 | } // namespace |
202 | |
203 | /// Parse an enum from the keyword, or default to the provided default value. |
204 | /// The return type is the enum type by default, unless overridden with the |
205 | /// second template argument. |
206 | template <typename EnumTy, typename RetTy = EnumTy> |
207 | static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, |
208 | OperationState &result, |
209 | EnumTy defaultValue) { |
210 | SmallVector<StringRef, 10> names; |
211 | for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) |
212 | names.push_back(Elt: EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); |
213 | |
214 | int index = parseOptionalKeywordAlternative(parser, keywords: names); |
215 | if (index == -1) |
216 | return static_cast<RetTy>(defaultValue); |
217 | return static_cast<RetTy>(index); |
218 | } |
219 | |
220 | //===----------------------------------------------------------------------===// |
// Printing, parsing and folding for LLVM::ICmpOp and LLVM::FCmpOp.
222 | //===----------------------------------------------------------------------===// |
223 | |
224 | void ICmpOp::print(OpAsmPrinter &p) { |
225 | p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) |
226 | << ", " << getOperand(1); |
227 | p.printOptionalAttrDict((*this)->getAttrs(), {"predicate" }); |
228 | p << " : " << getLhs().getType(); |
229 | } |
230 | |
231 | void FCmpOp::print(OpAsmPrinter &p) { |
232 | p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) |
233 | << ", " << getOperand(1); |
234 | p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate" }); |
235 | p << " : " << getLhs().getType(); |
236 | } |
237 | |
238 | // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use |
239 | // attribute-dict? `:` type |
240 | // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use |
241 | // attribute-dict? `:` type |
242 | template <typename CmpPredicateType> |
243 | static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { |
244 | StringAttr predicateAttr; |
245 | OpAsmParser::UnresolvedOperand lhs, rhs; |
246 | Type type; |
247 | SMLoc predicateLoc, trailingTypeLoc; |
248 | if (parser.getCurrentLocation(loc: &predicateLoc) || |
249 | parser.parseAttribute(predicateAttr, "predicate" , result.attributes) || |
250 | parser.parseOperand(result&: lhs) || parser.parseComma() || |
251 | parser.parseOperand(result&: rhs) || |
252 | parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() || |
253 | parser.getCurrentLocation(loc: &trailingTypeLoc) || parser.parseType(result&: type) || |
254 | parser.resolveOperand(operand: lhs, type, result&: result.operands) || |
255 | parser.resolveOperand(operand: rhs, type, result&: result.operands)) |
256 | return failure(); |
257 | |
258 | // Replace the string attribute `predicate` with an integer attribute. |
259 | int64_t predicateValue = 0; |
260 | if (std::is_same<CmpPredicateType, ICmpPredicate>()) { |
261 | std::optional<ICmpPredicate> predicate = |
262 | symbolizeICmpPredicate(predicateAttr.getValue()); |
263 | if (!predicate) |
264 | return parser.emitError(loc: predicateLoc) |
265 | << "'" << predicateAttr.getValue() |
266 | << "' is an incorrect value of the 'predicate' attribute" ; |
267 | predicateValue = static_cast<int64_t>(*predicate); |
268 | } else { |
269 | std::optional<FCmpPredicate> predicate = |
270 | symbolizeFCmpPredicate(predicateAttr.getValue()); |
271 | if (!predicate) |
272 | return parser.emitError(loc: predicateLoc) |
273 | << "'" << predicateAttr.getValue() |
274 | << "' is an incorrect value of the 'predicate' attribute" ; |
275 | predicateValue = static_cast<int64_t>(*predicate); |
276 | } |
277 | |
278 | result.attributes.set("predicate" , |
279 | parser.getBuilder().getI64IntegerAttr(predicateValue)); |
280 | |
281 | // The result type is either i1 or a vector type <? x i1> if the inputs are |
282 | // vectors. |
283 | if (!isCompatibleType(type)) |
284 | return parser.emitError(loc: trailingTypeLoc, |
285 | message: "expected LLVM dialect-compatible type" ); |
286 | result.addTypes(newTypes: getI1SameShape(type)); |
287 | return success(); |
288 | } |
289 | |
290 | ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { |
291 | return parseCmpOp<ICmpPredicate>(parser, result); |
292 | } |
293 | |
294 | ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { |
295 | return parseCmpOp<FCmpPredicate>(parser, result); |
296 | } |
297 | |
298 | /// Returns a scalar or vector boolean attribute of the given type. |
299 | static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { |
300 | auto boolAttr = BoolAttr::get(context: ctx, value); |
301 | ShapedType shapedType = dyn_cast<ShapedType>(type); |
302 | if (!shapedType) |
303 | return boolAttr; |
304 | return DenseElementsAttr::get(shapedType, boolAttr); |
305 | } |
306 | |
307 | OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) { |
308 | if (getPredicate() != ICmpPredicate::eq && |
309 | getPredicate() != ICmpPredicate::ne) |
310 | return {}; |
311 | |
312 | // cmpi(eq/ne, x, x) -> true/false |
313 | if (getLhs() == getRhs()) |
314 | return getBoolAttribute(getType(), getContext(), |
315 | getPredicate() == ICmpPredicate::eq); |
316 | |
317 | // cmpi(eq/ne, alloca, null) -> false/true |
318 | if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>()) |
319 | return getBoolAttribute(getType(), getContext(), |
320 | getPredicate() == ICmpPredicate::ne); |
321 | |
322 | // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null) |
323 | if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) { |
324 | Value lhs = getLhs(); |
325 | Value rhs = getRhs(); |
326 | getLhsMutable().assign(rhs); |
327 | getRhsMutable().assign(lhs); |
328 | return getResult(); |
329 | } |
330 | |
331 | return {}; |
332 | } |
333 | |
334 | //===----------------------------------------------------------------------===// |
335 | // Printing, parsing and verification for LLVM::AllocaOp. |
336 | //===----------------------------------------------------------------------===// |
337 | |
338 | void AllocaOp::print(OpAsmPrinter &p) { |
339 | auto funcTy = |
340 | FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); |
341 | |
342 | if (getInalloca()) |
343 | p << " inalloca" ; |
344 | |
345 | p << ' ' << getArraySize() << " x " << getElemType(); |
346 | if (getAlignment() && *getAlignment() != 0) |
347 | p.printOptionalAttrDict((*this)->getAttrs(), |
348 | {kElemTypeAttrName, getInallocaAttrName()}); |
349 | else |
350 | p.printOptionalAttrDict( |
351 | (*this)->getAttrs(), |
352 | {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()}); |
353 | p << " : " << funcTy; |
354 | } |
355 | |
356 | // <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type |
357 | // attribute-dict? `:` type `,` type |
358 | ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { |
359 | OpAsmParser::UnresolvedOperand arraySize; |
360 | Type type, elemType; |
361 | SMLoc trailingTypeLoc; |
362 | |
363 | if (succeeded(parser.parseOptionalKeyword("inalloca" ))) |
364 | result.addAttribute(getInallocaAttrName(result.name), |
365 | UnitAttr::get(parser.getContext())); |
366 | |
367 | if (parser.parseOperand(arraySize) || parser.parseKeyword("x" ) || |
368 | parser.parseType(elemType) || |
369 | parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
370 | parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) |
371 | return failure(); |
372 | |
373 | std::optional<NamedAttribute> alignmentAttr = |
374 | result.attributes.getNamed("alignment" ); |
375 | if (alignmentAttr.has_value()) { |
376 | auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue()); |
377 | if (!alignmentInt) |
378 | return parser.emitError(parser.getNameLoc(), |
379 | "expected integer alignment" ); |
380 | if (alignmentInt.getValue().isZero()) |
381 | result.attributes.erase("alignment" ); |
382 | } |
383 | |
384 | // Extract the result type from the trailing function type. |
385 | auto funcType = llvm::dyn_cast<FunctionType>(type); |
386 | if (!funcType || funcType.getNumInputs() != 1 || |
387 | funcType.getNumResults() != 1) |
388 | return parser.emitError( |
389 | trailingTypeLoc, |
390 | "expected trailing function type with one argument and one result" ); |
391 | |
392 | if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) |
393 | return failure(); |
394 | |
395 | Type resultType = funcType.getResult(0); |
396 | if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType)) |
397 | result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); |
398 | |
399 | result.addTypes({funcType.getResult(0)}); |
400 | return success(); |
401 | } |
402 | |
403 | LogicalResult AllocaOp::verify() { |
404 | // Only certain target extension types can be used in 'alloca'. |
405 | if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType()); |
406 | targetExtType && !targetExtType.supportsMemOps()) |
407 | return emitOpError() |
408 | << "this target extension type cannot be used in alloca" ; |
409 | |
410 | return success(); |
411 | } |
412 | |
413 | Type AllocaOp::getResultPtrElementType() { return getElemType(); } |
414 | |
415 | //===----------------------------------------------------------------------===// |
416 | // LLVM::BrOp |
417 | //===----------------------------------------------------------------------===// |
418 | |
419 | SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { |
420 | assert(index == 0 && "invalid successor index" ); |
421 | return SuccessorOperands(getDestOperandsMutable()); |
422 | } |
423 | |
424 | //===----------------------------------------------------------------------===// |
425 | // LLVM::CondBrOp |
426 | //===----------------------------------------------------------------------===// |
427 | |
428 | SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { |
429 | assert(index < getNumSuccessors() && "invalid successor index" ); |
430 | return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() |
431 | : getFalseDestOperandsMutable()); |
432 | } |
433 | |
434 | void CondBrOp::build(OpBuilder &builder, OperationState &result, |
435 | Value condition, Block *trueDest, ValueRange trueOperands, |
436 | Block *falseDest, ValueRange falseOperands, |
437 | std::optional<std::pair<uint32_t, uint32_t>> weights) { |
438 | DenseI32ArrayAttr weightsAttr; |
439 | if (weights) |
440 | weightsAttr = |
441 | builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first), |
442 | static_cast<int32_t>(weights->second)}); |
443 | |
444 | build(builder, result, condition, trueOperands, falseOperands, weightsAttr, |
445 | /*loop_annotation=*/{}, trueDest, falseDest); |
446 | } |
447 | |
448 | //===----------------------------------------------------------------------===// |
449 | // LLVM::SwitchOp |
450 | //===----------------------------------------------------------------------===// |
451 | |
452 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
453 | Block *defaultDestination, ValueRange defaultOperands, |
454 | DenseIntElementsAttr caseValues, |
455 | BlockRange caseDestinations, |
456 | ArrayRef<ValueRange> caseOperands, |
457 | ArrayRef<int32_t> branchWeights) { |
458 | DenseI32ArrayAttr weightsAttr; |
459 | if (!branchWeights.empty()) |
460 | weightsAttr = builder.getDenseI32ArrayAttr(branchWeights); |
461 | |
462 | build(builder, result, value, defaultOperands, caseOperands, caseValues, |
463 | weightsAttr, defaultDestination, caseDestinations); |
464 | } |
465 | |
466 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
467 | Block *defaultDestination, ValueRange defaultOperands, |
468 | ArrayRef<APInt> caseValues, BlockRange caseDestinations, |
469 | ArrayRef<ValueRange> caseOperands, |
470 | ArrayRef<int32_t> branchWeights) { |
471 | DenseIntElementsAttr caseValuesAttr; |
472 | if (!caseValues.empty()) { |
473 | ShapedType caseValueType = VectorType::get( |
474 | static_cast<int64_t>(caseValues.size()), value.getType()); |
475 | caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
476 | } |
477 | |
478 | build(builder, result, value, defaultDestination, defaultOperands, |
479 | caseValuesAttr, caseDestinations, caseOperands, branchWeights); |
480 | } |
481 | |
482 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
483 | Block *defaultDestination, ValueRange defaultOperands, |
484 | ArrayRef<int32_t> caseValues, BlockRange caseDestinations, |
485 | ArrayRef<ValueRange> caseOperands, |
486 | ArrayRef<int32_t> branchWeights) { |
487 | DenseIntElementsAttr caseValuesAttr; |
488 | if (!caseValues.empty()) { |
489 | ShapedType caseValueType = VectorType::get( |
490 | static_cast<int64_t>(caseValues.size()), value.getType()); |
491 | caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
492 | } |
493 | |
494 | build(builder, result, value, defaultDestination, defaultOperands, |
495 | caseValuesAttr, caseDestinations, caseOperands, branchWeights); |
496 | } |
497 | |
498 | /// <cases> ::= `[` (case (`,` case )* )? `]` |
499 | /// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? |
500 | static ParseResult parseSwitchOpCases( |
501 | OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, |
502 | SmallVectorImpl<Block *> &caseDestinations, |
503 | SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, |
504 | SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { |
505 | if (failed(result: parser.parseLSquare())) |
506 | return failure(); |
507 | if (succeeded(result: parser.parseOptionalRSquare())) |
508 | return success(); |
509 | SmallVector<APInt> values; |
510 | unsigned bitWidth = flagType.getIntOrFloatBitWidth(); |
511 | auto parseCase = [&]() { |
512 | int64_t value = 0; |
513 | if (failed(result: parser.parseInteger(result&: value))) |
514 | return failure(); |
515 | values.push_back(Elt: APInt(bitWidth, value)); |
516 | |
517 | Block *destination; |
518 | SmallVector<OpAsmParser::UnresolvedOperand> operands; |
519 | SmallVector<Type> operandTypes; |
520 | if (parser.parseColon() || parser.parseSuccessor(dest&: destination)) |
521 | return failure(); |
522 | if (!parser.parseOptionalLParen()) { |
523 | if (parser.parseOperandList(result&: operands, delimiter: OpAsmParser::Delimiter::None, |
524 | /*allowResultNumber=*/false) || |
525 | parser.parseColonTypeList(result&: operandTypes) || parser.parseRParen()) |
526 | return failure(); |
527 | } |
528 | caseDestinations.push_back(Elt: destination); |
529 | caseOperands.emplace_back(Args&: operands); |
530 | caseOperandTypes.emplace_back(Args&: operandTypes); |
531 | return success(); |
532 | }; |
533 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: parseCase))) |
534 | return failure(); |
535 | |
536 | ShapedType caseValueType = |
537 | VectorType::get(static_cast<int64_t>(values.size()), flagType); |
538 | caseValues = DenseIntElementsAttr::get(caseValueType, values); |
539 | return parser.parseRSquare(); |
540 | } |
541 | |
542 | static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, |
543 | DenseIntElementsAttr caseValues, |
544 | SuccessorRange caseDestinations, |
545 | OperandRangeRange caseOperands, |
546 | const TypeRangeRange &caseOperandTypes) { |
547 | p << '['; |
548 | p.printNewline(); |
549 | if (!caseValues) { |
550 | p << ']'; |
551 | return; |
552 | } |
553 | |
554 | size_t index = 0; |
555 | llvm::interleave( |
556 | c: llvm::zip(t&: caseValues, u&: caseDestinations), |
557 | each_fn: [&](auto i) { |
558 | p << " " ; |
559 | p << std::get<0>(i).getLimitedValue(); |
560 | p << ": " ; |
561 | p.printSuccessorAndUseList(successor: std::get<1>(i), succOperands: caseOperands[index++]); |
562 | }, |
563 | between_fn: [&] { |
564 | p << ','; |
565 | p.printNewline(); |
566 | }); |
567 | p.printNewline(); |
568 | p << ']'; |
569 | } |
570 | |
571 | LogicalResult SwitchOp::verify() { |
572 | if ((!getCaseValues() && !getCaseDestinations().empty()) || |
573 | (getCaseValues() && |
574 | getCaseValues()->size() != |
575 | static_cast<int64_t>(getCaseDestinations().size()))) |
576 | return emitOpError("expects number of case values to match number of " |
577 | "case destinations" ); |
578 | if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) |
579 | return emitError("expects number of branch weights to match number of " |
580 | "successors: " ) |
581 | << getBranchWeights()->size() << " vs " << getNumSuccessors(); |
582 | if (getCaseValues() && |
583 | getValue().getType() != getCaseValues()->getElementType()) |
584 | return emitError("expects case value type to match condition value type" ); |
585 | return success(); |
586 | } |
587 | |
588 | SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { |
589 | assert(index < getNumSuccessors() && "invalid successor index" ); |
590 | return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() |
591 | : getCaseOperandsMutable(index - 1)); |
592 | } |
593 | |
594 | //===----------------------------------------------------------------------===// |
595 | // Code for LLVM::GEPOp. |
596 | //===----------------------------------------------------------------------===// |
597 | |
598 | constexpr int32_t GEPOp::kDynamicIndex; |
599 | |
600 | GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() { |
601 | return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(), |
602 | getDynamicIndices()); |
603 | } |
604 | |
605 | /// Returns the elemental type of any LLVM-compatible vector type or self. |
606 | static Type (Type type) { |
607 | if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
608 | return vectorType.getElementType(); |
609 | if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type)) |
610 | return scalableVectorType.getElementType(); |
611 | if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type)) |
612 | return fixedVectorType.getElementType(); |
613 | return type; |
614 | } |
615 | |
616 | /// Destructures the 'indices' parameter into 'rawConstantIndices' and |
617 | /// 'dynamicIndices', encoding the former in the process. In the process, |
618 | /// dynamic indices which are used to index into a structure type are converted |
619 | /// to constant indices when possible. To do this, the GEPs element type should |
620 | /// be passed as first parameter. |
621 | static void destructureIndices(Type currType, ArrayRef<GEPArg> indices, |
622 | SmallVectorImpl<int32_t> &rawConstantIndices, |
623 | SmallVectorImpl<Value> &dynamicIndices) { |
624 | for (const GEPArg &iter : indices) { |
625 | // If the thing we are currently indexing into is a struct we must turn |
626 | // any integer constants into constant indices. If this is not possible |
627 | // we don't do anything here. The verifier will catch it and emit a proper |
628 | // error. All other canonicalization is done in the fold method. |
629 | bool requiresConst = !rawConstantIndices.empty() && |
630 | isa_and_nonnull<LLVMStructType>(Val: currType); |
631 | if (Value val = llvm::dyn_cast_if_present<Value>(Val: iter)) { |
632 | APInt intC; |
633 | if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) && |
634 | intC.isSignedIntN(kGEPConstantBitWidth)) { |
635 | rawConstantIndices.push_back(Elt: intC.getSExtValue()); |
636 | } else { |
637 | rawConstantIndices.push_back(GEPOp::kDynamicIndex); |
638 | dynamicIndices.push_back(Elt: val); |
639 | } |
640 | } else { |
641 | rawConstantIndices.push_back(Elt: iter.get<GEPConstantIndex>()); |
642 | } |
643 | |
644 | // Skip for very first iteration of this loop. First index does not index |
645 | // within the aggregates, but is just a pointer offset. |
646 | if (rawConstantIndices.size() == 1 || !currType) |
647 | continue; |
648 | |
649 | currType = |
650 | TypeSwitch<Type, Type>(currType) |
651 | .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, |
652 | LLVMArrayType>([](auto containerType) { |
653 | return containerType.getElementType(); |
654 | }) |
655 | .Case([&](LLVMStructType structType) -> Type { |
656 | int64_t memberIndex = rawConstantIndices.back(); |
657 | if (memberIndex >= 0 && static_cast<size_t>(memberIndex) < |
658 | structType.getBody().size()) |
659 | return structType.getBody()[memberIndex]; |
660 | return nullptr; |
661 | }) |
662 | .Default(Type(nullptr)); |
663 | } |
664 | } |
665 | |
666 | void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, |
667 | Type elementType, Value basePtr, ArrayRef<GEPArg> indices, |
668 | bool inbounds, ArrayRef<NamedAttribute> attributes) { |
669 | SmallVector<int32_t> rawConstantIndices; |
670 | SmallVector<Value> dynamicIndices; |
671 | destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices); |
672 | |
673 | result.addTypes(resultType); |
674 | result.addAttributes(attributes); |
675 | result.addAttribute(getRawConstantIndicesAttrName(result.name), |
676 | builder.getDenseI32ArrayAttr(rawConstantIndices)); |
677 | if (inbounds) { |
678 | result.addAttribute(getInboundsAttrName(result.name), |
679 | builder.getUnitAttr()); |
680 | } |
681 | result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); |
682 | result.addOperands(basePtr); |
683 | result.addOperands(dynamicIndices); |
684 | } |
685 | |
686 | void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, |
687 | Type elementType, Value basePtr, ValueRange indices, |
688 | bool inbounds, ArrayRef<NamedAttribute> attributes) { |
689 | build(builder, result, resultType, elementType, basePtr, |
690 | SmallVector<GEPArg>(indices), inbounds, attributes); |
691 | } |
692 | |
693 | static ParseResult |
694 | parseGEPIndices(OpAsmParser &parser, |
695 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, |
696 | DenseI32ArrayAttr &rawConstantIndices) { |
697 | SmallVector<int32_t> constantIndices; |
698 | |
699 | auto idxParser = [&]() -> ParseResult { |
700 | int32_t constantIndex; |
701 | OptionalParseResult parsedInteger = |
702 | parser.parseOptionalInteger(result&: constantIndex); |
703 | if (parsedInteger.has_value()) { |
704 | if (failed(result: parsedInteger.value())) |
705 | return failure(); |
706 | constantIndices.push_back(Elt: constantIndex); |
707 | return success(); |
708 | } |
709 | |
710 | constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); |
711 | return parser.parseOperand(result&: indices.emplace_back()); |
712 | }; |
713 | if (parser.parseCommaSeparatedList(parseElementFn: idxParser)) |
714 | return failure(); |
715 | |
716 | rawConstantIndices = |
717 | DenseI32ArrayAttr::get(parser.getContext(), constantIndices); |
718 | return success(); |
719 | } |
720 | |
721 | static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, |
722 | OperandRange indices, |
723 | DenseI32ArrayAttr rawConstantIndices) { |
724 | llvm::interleaveComma( |
725 | c: GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), os&: printer, |
726 | each_fn: [&](PointerUnion<IntegerAttr, Value> cst) { |
727 | if (Value val = llvm::dyn_cast_if_present<Value>(Val&: cst)) |
728 | printer.printOperand(value: val); |
729 | else |
730 | printer << cst.get<IntegerAttr>().getInt(); |
731 | }); |
732 | } |
733 | |
734 | /// For the given `indices`, check if they comply with `baseGEPType`, |
735 | /// especially check against LLVMStructTypes nested within. |
736 | static LogicalResult |
737 | verifyStructIndices(Type baseGEPType, unsigned indexPos, |
738 | GEPIndicesAdaptor<ValueRange> indices, |
739 | function_ref<InFlightDiagnostic()> emitOpError) { |
740 | if (indexPos >= indices.size()) |
741 | // Stop searching |
742 | return success(); |
743 | |
744 | return TypeSwitch<Type, LogicalResult>(baseGEPType) |
745 | .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult { |
746 | if (!indices[indexPos].is<IntegerAttr>()) |
747 | return emitOpError() << "expected index " << indexPos |
748 | << " indexing a struct to be constant" ; |
749 | |
750 | int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt(); |
751 | ArrayRef<Type> elementTypes = structType.getBody(); |
752 | if (gepIndex < 0 || |
753 | static_cast<size_t>(gepIndex) >= elementTypes.size()) |
754 | return emitOpError() << "index " << indexPos |
755 | << " indexing a struct is out of bounds" ; |
756 | |
757 | // Instead of recursively going into every children types, we only |
758 | // dive into the one indexed by gepIndex. |
759 | return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, |
760 | indices, emitOpError); |
761 | }) |
762 | .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, |
763 | LLVMArrayType>([&](auto containerType) -> LogicalResult { |
764 | return verifyStructIndices(containerType.getElementType(), indexPos + 1, |
765 | indices, emitOpError); |
766 | }) |
767 | .Default([&](auto otherType) -> LogicalResult { |
768 | return emitOpError() |
769 | << "type " << otherType << " cannot be indexed (index #" |
770 | << indexPos << ")" ; |
771 | }); |
772 | } |
773 | |
774 | /// Driver function around `verifyStructIndices`. |
775 | static LogicalResult |
776 | verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices, |
777 | function_ref<InFlightDiagnostic()> emitOpError) { |
778 | return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError); |
779 | } |
780 | |
781 | LogicalResult LLVM::GEPOp::verify() { |
782 | if (static_cast<size_t>( |
783 | llvm::count(getRawConstantIndices(), kDynamicIndex)) != |
784 | getDynamicIndices().size()) |
785 | return emitOpError("expected as many dynamic indices as specified in '" ) |
786 | << getRawConstantIndicesAttrName().getValue() << "'" ; |
787 | |
788 | return verifyStructIndices(getElemType(), getIndices(), |
789 | [&] { return emitOpError(); }); |
790 | } |
791 | |
792 | Type GEPOp::getResultPtrElementType() { |
793 | // Set the initial type currently being used for indexing. This will be |
794 | // updated as the indices get walked over. |
795 | Type selectedType = getElemType(); |
796 | |
797 | // Follow the indexed elements in the gep. |
798 | auto indices = getIndices(); |
799 | for (GEPIndicesAdaptor<ValueRange>::value_type index : |
800 | llvm::drop_begin(indices)) { |
801 | // GEPs can only index into aggregates which can be structs or arrays. |
802 | |
803 | // The resulting type if indexing into an array type is always the element |
804 | // type, regardless of index. |
805 | if (auto arrayType = dyn_cast<LLVMArrayType>(selectedType)) { |
806 | selectedType = arrayType.getElementType(); |
807 | continue; |
808 | } |
809 | |
810 | // The GEP verifier ensures that any index into structs are static and |
811 | // that they refer to a field within the struct. |
812 | selectedType = cast<DestructurableTypeInterface>(selectedType) |
813 | .getTypeAtIndex(cast<IntegerAttr>(index)); |
814 | } |
815 | |
816 | // When there are no more indices, the type currently being used for indexing |
817 | // is the type of the value pointed at by the returned indexed pointer. |
818 | return selectedType; |
819 | } |
820 | |
821 | //===----------------------------------------------------------------------===// |
822 | // LoadOp |
823 | //===----------------------------------------------------------------------===// |
824 | |
825 | void LoadOp::getEffects( |
826 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
827 | &effects) { |
828 | effects.emplace_back(MemoryEffects::Read::get(), getAddr()); |
829 | // Volatile operations can have target-specific read-write effects on |
830 | // memory besides the one referred to by the pointer operand. |
831 | // Similarly, atomic operations that are monotonic or stricter cause |
832 | // synchronization that from a language point-of-view, are arbitrary |
833 | // read-writes into memory. |
834 | if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && |
835 | getOrdering() != AtomicOrdering::unordered)) { |
836 | effects.emplace_back(MemoryEffects::Write::get()); |
837 | effects.emplace_back(MemoryEffects::Read::get()); |
838 | } |
839 | } |
840 | |
841 | /// Returns true if the given type is supported by atomic operations. All |
842 | /// integer and float types with limited bit width are supported. Additionally, |
843 | /// depending on the operation pointers may be supported as well. |
844 | static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) { |
845 | if (llvm::isa<LLVMPointerType>(type)) |
846 | return isPointerTypeAllowed; |
847 | |
848 | std::optional<unsigned> bitWidth; |
849 | if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) { |
850 | if (!isCompatibleFloatingPointType(type)) |
851 | return false; |
852 | bitWidth = floatType.getWidth(); |
853 | } |
854 | if (auto integerType = llvm::dyn_cast<IntegerType>(type)) |
855 | bitWidth = integerType.getWidth(); |
856 | // The type is neither an integer, float, or pointer type. |
857 | if (!bitWidth) |
858 | return false; |
859 | return *bitWidth == 8 || *bitWidth == 16 || *bitWidth == 32 || |
860 | *bitWidth == 64; |
861 | } |
862 | |
863 | /// Verifies the attributes and the type of atomic memory access operations. |
864 | template <typename OpTy> |
865 | LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, |
866 | ArrayRef<AtomicOrdering> unsupportedOrderings) { |
867 | if (memOp.getOrdering() != AtomicOrdering::not_atomic) { |
868 | if (!isTypeCompatibleWithAtomicOp(type: valueType, |
869 | /*isPointerTypeAllowed=*/true)) |
870 | return memOp.emitOpError("unsupported type " ) |
871 | << valueType << " for atomic access" ; |
872 | if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering())) |
873 | return memOp.emitOpError("unsupported ordering '" ) |
874 | << stringifyAtomicOrdering(memOp.getOrdering()) << "'" ; |
875 | if (!memOp.getAlignment()) |
876 | return memOp.emitOpError("expected alignment for atomic access" ); |
877 | return success(); |
878 | } |
879 | if (memOp.getSyncscope()) |
880 | return memOp.emitOpError( |
881 | "expected syncscope to be null for non-atomic access" ); |
882 | return success(); |
883 | } |
884 | |
885 | LogicalResult LoadOp::verify() { |
886 | Type valueType = getResult().getType(); |
887 | return verifyAtomicMemOp(*this, valueType, |
888 | {AtomicOrdering::release, AtomicOrdering::acq_rel}); |
889 | } |
890 | |
891 | void LoadOp::build(OpBuilder &builder, OperationState &state, Type type, |
892 | Value addr, unsigned alignment, bool isVolatile, |
893 | bool isNonTemporal, bool isInvariant, |
894 | AtomicOrdering ordering, StringRef syncscope) { |
895 | build(builder, state, type, addr, |
896 | alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, |
897 | isNonTemporal, isInvariant, ordering, |
898 | syncscope.empty() ? nullptr : builder.getStringAttr(syncscope), |
899 | /*access_groups=*/nullptr, |
900 | /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, |
901 | /*tbaa=*/nullptr); |
902 | } |
903 | |
904 | //===----------------------------------------------------------------------===// |
905 | // StoreOp |
906 | //===----------------------------------------------------------------------===// |
907 | |
908 | void StoreOp::getEffects( |
909 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
910 | &effects) { |
911 | effects.emplace_back(MemoryEffects::Write::get(), getAddr()); |
912 | // Volatile operations can have target-specific read-write effects on |
913 | // memory besides the one referred to by the pointer operand. |
914 | // Similarly, atomic operations that are monotonic or stricter cause |
915 | // synchronization that from a language point-of-view, are arbitrary |
916 | // read-writes into memory. |
917 | if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && |
918 | getOrdering() != AtomicOrdering::unordered)) { |
919 | effects.emplace_back(MemoryEffects::Write::get()); |
920 | effects.emplace_back(MemoryEffects::Read::get()); |
921 | } |
922 | } |
923 | |
924 | LogicalResult StoreOp::verify() { |
925 | Type valueType = getValue().getType(); |
926 | return verifyAtomicMemOp(*this, valueType, |
927 | {AtomicOrdering::acquire, AtomicOrdering::acq_rel}); |
928 | } |
929 | |
930 | void StoreOp::build(OpBuilder &builder, OperationState &state, Value value, |
931 | Value addr, unsigned alignment, bool isVolatile, |
932 | bool isNonTemporal, AtomicOrdering ordering, |
933 | StringRef syncscope) { |
934 | build(builder, state, value, addr, |
935 | alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, |
936 | isNonTemporal, ordering, |
937 | syncscope.empty() ? nullptr : builder.getStringAttr(syncscope), |
938 | /*access_groups=*/nullptr, |
939 | /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
940 | } |
941 | |
942 | //===----------------------------------------------------------------------===// |
943 | // CallOp |
944 | //===----------------------------------------------------------------------===// |
945 | |
946 | /// Gets the MLIR Op-like result types of a LLVMFunctionType. |
947 | static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) { |
948 | SmallVector<Type, 1> results; |
949 | Type resultType = calleeType.getReturnType(); |
950 | if (!isa<LLVM::LLVMVoidType>(Val: resultType)) |
951 | results.push_back(Elt: resultType); |
952 | return results; |
953 | } |
954 | |
955 | /// Constructs a LLVMFunctionType from MLIR `results` and `args`. |
956 | static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, |
957 | ValueRange args) { |
958 | Type resultType; |
959 | if (results.empty()) |
960 | resultType = LLVMVoidType::get(ctx: context); |
961 | else |
962 | resultType = results.front(); |
963 | return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()), |
964 | /*isVarArg=*/false); |
965 | } |
966 | |
967 | void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, |
968 | StringRef callee, ValueRange args) { |
969 | build(builder, state, results, builder.getStringAttr(callee), args); |
970 | } |
971 | |
972 | void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, |
973 | StringAttr callee, ValueRange args) { |
974 | build(builder, state, results, SymbolRefAttr::get(callee), args); |
975 | } |
976 | |
977 | void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, |
978 | FlatSymbolRefAttr callee, ValueRange args) { |
979 | assert(callee && "expected non-null callee in direct call builder" ); |
980 | build(builder, state, results, |
981 | TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)), |
982 | callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, |
983 | /*CConv=*/nullptr, |
984 | /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, |
985 | /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
986 | } |
987 | |
988 | void CallOp::build(OpBuilder &builder, OperationState &state, |
989 | LLVMFunctionType calleeType, StringRef callee, |
990 | ValueRange args) { |
991 | build(builder, state, calleeType, builder.getStringAttr(callee), args); |
992 | } |
993 | |
994 | void CallOp::build(OpBuilder &builder, OperationState &state, |
995 | LLVMFunctionType calleeType, StringAttr callee, |
996 | ValueRange args) { |
997 | build(builder, state, calleeType, SymbolRefAttr::get(callee), args); |
998 | } |
999 | |
1000 | void CallOp::build(OpBuilder &builder, OperationState &state, |
1001 | LLVMFunctionType calleeType, FlatSymbolRefAttr callee, |
1002 | ValueRange args) { |
1003 | build(builder, state, getCallOpResultTypes(calleeType), |
1004 | TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr, |
1005 | /*branch_weights=*/nullptr, /*CConv=*/nullptr, |
1006 | /*access_groups=*/nullptr, |
1007 | /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
1008 | } |
1009 | |
1010 | void CallOp::build(OpBuilder &builder, OperationState &state, |
1011 | LLVMFunctionType calleeType, ValueRange args) { |
1012 | build(builder, state, getCallOpResultTypes(calleeType), |
1013 | TypeAttr::get(calleeType), /*callee=*/nullptr, args, |
1014 | /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, |
1015 | /*CConv=*/nullptr, |
1016 | /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, |
1017 | /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
1018 | } |
1019 | |
1020 | void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, |
1021 | ValueRange args) { |
1022 | auto calleeType = func.getFunctionType(); |
1023 | build(builder, state, getCallOpResultTypes(calleeType), |
1024 | TypeAttr::get(calleeType), SymbolRefAttr::get(func), args, |
1025 | /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, |
1026 | /*CConv=*/nullptr, |
1027 | /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, |
1028 | /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
1029 | } |
1030 | |
1031 | CallInterfaceCallable CallOp::getCallableForCallee() { |
1032 | // Direct call. |
1033 | if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) |
1034 | return calleeAttr; |
1035 | // Indirect call, callee Value is the first operand. |
1036 | return getOperand(0); |
1037 | } |
1038 | |
1039 | void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
1040 | // Direct call. |
1041 | if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { |
1042 | auto symRef = callee.get<SymbolRefAttr>(); |
1043 | return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef)); |
1044 | } |
1045 | // Indirect call, callee Value is the first operand. |
1046 | return setOperand(0, callee.get<Value>()); |
1047 | } |
1048 | |
1049 | Operation::operand_range CallOp::getArgOperands() { |
1050 | return getOperands().drop_front(getCallee().has_value() ? 0 : 1); |
1051 | } |
1052 | |
1053 | MutableOperandRange CallOp::getArgOperandsMutable() { |
1054 | return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1, |
1055 | getCalleeOperands().size()); |
1056 | } |
1057 | |
1058 | /// Verify that an inlinable callsite of a debug-info-bearing function in a |
1059 | /// debug-info-bearing function has a debug location attached to it. This |
1060 | /// mirrors an LLVM IR verifier. |
1061 | static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) { |
1062 | if (callee.isExternal()) |
1063 | return success(); |
1064 | auto parentFunc = callOp->getParentOfType<FunctionOpInterface>(); |
1065 | if (!parentFunc) |
1066 | return success(); |
1067 | |
1068 | auto hasSubprogram = [](Operation *op) { |
1069 | return op->getLoc() |
1070 | ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() != |
1071 | nullptr; |
1072 | }; |
1073 | if (!hasSubprogram(parentFunc) || !hasSubprogram(callee)) |
1074 | return success(); |
1075 | bool containsLoc = !isa<UnknownLoc>(callOp->getLoc()); |
1076 | if (!containsLoc) |
1077 | return callOp.emitError() |
1078 | << "inlinable function call in a function with a DISubprogram " |
1079 | "location must have a debug location" ; |
1080 | return success(); |
1081 | } |
1082 | |
1083 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1084 | if (getNumResults() > 1) |
1085 | return emitOpError("must have 0 or 1 result" ); |
1086 | |
1087 | // Type for the callee, we'll get it differently depending if it is a direct |
1088 | // or indirect call. |
1089 | Type fnType; |
1090 | |
1091 | bool isIndirect = false; |
1092 | |
1093 | // If this is an indirect call, the callee attribute is missing. |
1094 | FlatSymbolRefAttr calleeName = getCalleeAttr(); |
1095 | if (!calleeName) { |
1096 | isIndirect = true; |
1097 | if (!getNumOperands()) |
1098 | return emitOpError( |
1099 | "must have either a `callee` attribute or at least an operand" ); |
1100 | auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType()); |
1101 | if (!ptrType) |
1102 | return emitOpError("indirect call expects a pointer as callee: " ) |
1103 | << getOperand(0).getType(); |
1104 | |
1105 | return success(); |
1106 | } else { |
1107 | Operation *callee = |
1108 | symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr()); |
1109 | if (!callee) |
1110 | return emitOpError() |
1111 | << "'" << calleeName.getValue() |
1112 | << "' does not reference a symbol in the current scope" ; |
1113 | auto fn = dyn_cast<LLVMFuncOp>(callee); |
1114 | if (!fn) |
1115 | return emitOpError() << "'" << calleeName.getValue() |
1116 | << "' does not reference a valid LLVM function" ; |
1117 | |
1118 | if (failed(verifyCallOpDebugInfo(*this, fn))) |
1119 | return failure(); |
1120 | fnType = fn.getFunctionType(); |
1121 | } |
1122 | |
1123 | LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType); |
1124 | if (!funcType) |
1125 | return emitOpError("callee does not have a functional type: " ) << fnType; |
1126 | |
1127 | if (funcType.isVarArg() && !getCalleeType()) |
1128 | return emitOpError() << "missing callee type attribute for vararg call" ; |
1129 | |
1130 | // Verify that the operand and result types match the callee. |
1131 | |
1132 | if (!funcType.isVarArg() && |
1133 | funcType.getNumParams() != (getNumOperands() - isIndirect)) |
1134 | return emitOpError() << "incorrect number of operands (" |
1135 | << (getNumOperands() - isIndirect) |
1136 | << ") for callee (expecting: " |
1137 | << funcType.getNumParams() << ")" ; |
1138 | |
1139 | if (funcType.getNumParams() > (getNumOperands() - isIndirect)) |
1140 | return emitOpError() << "incorrect number of operands (" |
1141 | << (getNumOperands() - isIndirect) |
1142 | << ") for varargs callee (expecting at least: " |
1143 | << funcType.getNumParams() << ")" ; |
1144 | |
1145 | for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) |
1146 | if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) |
1147 | return emitOpError() << "operand type mismatch for operand " << i << ": " |
1148 | << getOperand(i + isIndirect).getType() |
1149 | << " != " << funcType.getParamType(i); |
1150 | |
1151 | if (getNumResults() == 0 && |
1152 | !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType())) |
1153 | return emitOpError() << "expected function call to produce a value" ; |
1154 | |
1155 | if (getNumResults() != 0 && |
1156 | llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType())) |
1157 | return emitOpError() |
1158 | << "calling function with void result must not produce values" ; |
1159 | |
1160 | if (getNumResults() > 1) |
1161 | return emitOpError() |
1162 | << "expected LLVM function call to produce 0 or 1 result" ; |
1163 | |
1164 | if (getNumResults() && getResult().getType() != funcType.getReturnType()) |
1165 | return emitOpError() << "result type mismatch: " << getResult().getType() |
1166 | << " != " << funcType.getReturnType(); |
1167 | |
1168 | return success(); |
1169 | } |
1170 | |
1171 | void CallOp::print(OpAsmPrinter &p) { |
1172 | auto callee = getCallee(); |
1173 | bool isDirect = callee.has_value(); |
1174 | |
1175 | LLVMFunctionType calleeType; |
1176 | bool isVarArg = false; |
1177 | |
1178 | if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) { |
1179 | calleeType = *optionalCalleeType; |
1180 | isVarArg = calleeType.isVarArg(); |
1181 | } |
1182 | |
1183 | p << ' '; |
1184 | |
1185 | // Print calling convention. |
1186 | if (getCConv() != LLVM::CConv::C) |
1187 | p << stringifyCConv(getCConv()) << ' '; |
1188 | |
1189 | // Print the direct callee if present as a function attribute, or an indirect |
1190 | // callee (first operand) otherwise. |
1191 | if (isDirect) |
1192 | p.printSymbolName(callee.value()); |
1193 | else |
1194 | p << getOperand(0); |
1195 | |
1196 | auto args = getOperands().drop_front(isDirect ? 0 : 1); |
1197 | p << '(' << args << ')'; |
1198 | |
1199 | if (isVarArg) |
1200 | p << " vararg(" << calleeType << ")" ; |
1201 | |
1202 | p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), |
1203 | {getCConvAttrName(), "callee" , "callee_type" }); |
1204 | |
1205 | p << " : " ; |
1206 | if (!isDirect) |
1207 | p << getOperand(0).getType() << ", " ; |
1208 | |
1209 | // Reconstruct the function MLIR function type from operand and result types. |
1210 | p.printFunctionalType(args.getTypes(), getResultTypes()); |
1211 | } |
1212 | |
1213 | /// Parses the type of a call operation and resolves the operands if the parsing |
1214 | /// succeeds. Returns failure otherwise. |
1215 | static ParseResult parseCallTypeAndResolveOperands( |
1216 | OpAsmParser &parser, OperationState &result, bool isDirect, |
1217 | ArrayRef<OpAsmParser::UnresolvedOperand> operands) { |
1218 | SMLoc trailingTypesLoc = parser.getCurrentLocation(); |
1219 | SmallVector<Type> types; |
1220 | if (parser.parseColonTypeList(result&: types)) |
1221 | return failure(); |
1222 | |
1223 | if (isDirect && types.size() != 1) |
1224 | return parser.emitError(loc: trailingTypesLoc, |
1225 | message: "expected direct call to have 1 trailing type" ); |
1226 | if (!isDirect && types.size() != 2) |
1227 | return parser.emitError(loc: trailingTypesLoc, |
1228 | message: "expected indirect call to have 2 trailing types" ); |
1229 | |
1230 | auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val()); |
1231 | if (!funcType) |
1232 | return parser.emitError(loc: trailingTypesLoc, |
1233 | message: "expected trailing function type" ); |
1234 | if (funcType.getNumResults() > 1) |
1235 | return parser.emitError(loc: trailingTypesLoc, |
1236 | message: "expected function with 0 or 1 result" ); |
1237 | if (funcType.getNumResults() == 1 && |
1238 | llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0))) |
1239 | return parser.emitError(loc: trailingTypesLoc, |
1240 | message: "expected a non-void result type" ); |
1241 | |
1242 | // The head element of the types list matches the callee type for |
1243 | // indirect calls, while the types list is emtpy for direct calls. |
1244 | // Append the function input types to resolve the call operation |
1245 | // operands. |
1246 | llvm::append_range(types, funcType.getInputs()); |
1247 | if (parser.resolveOperands(operands, types, loc: parser.getNameLoc(), |
1248 | result&: result.operands)) |
1249 | return failure(); |
1250 | if (funcType.getNumResults() != 0) |
1251 | result.addTypes(funcType.getResults()); |
1252 | |
1253 | return success(); |
1254 | } |
1255 | |
1256 | /// Parses an optional function pointer operand before the call argument list |
1257 | /// for indirect calls, or stops parsing at the function identifier otherwise. |
1258 | static ParseResult parseOptionalCallFuncPtr( |
1259 | OpAsmParser &parser, |
1260 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) { |
1261 | OpAsmParser::UnresolvedOperand funcPtrOperand; |
1262 | OptionalParseResult parseResult = parser.parseOptionalOperand(result&: funcPtrOperand); |
1263 | if (parseResult.has_value()) { |
1264 | if (failed(result: *parseResult)) |
1265 | return *parseResult; |
1266 | operands.push_back(Elt: funcPtrOperand); |
1267 | } |
1268 | return success(); |
1269 | } |
1270 | |
1271 | // <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use) |
1272 | // `(` ssa-use-list `)` |
1273 | // ( `vararg(` var-arg-func-type `)` )? |
1274 | // attribute-dict? `:` (type `,`)? function-type |
1275 | ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { |
1276 | SymbolRefAttr funcAttr; |
1277 | TypeAttr calleeType; |
1278 | SmallVector<OpAsmParser::UnresolvedOperand> operands; |
1279 | |
1280 | // Default to C Calling Convention if no keyword is provided. |
1281 | result.addAttribute( |
1282 | getCConvAttrName(result.name), |
1283 | CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( |
1284 | parser, result, LLVM::CConv::C))); |
1285 | |
1286 | // Parse a function pointer for indirect calls. |
1287 | if (parseOptionalCallFuncPtr(parser, operands)) |
1288 | return failure(); |
1289 | bool isDirect = operands.empty(); |
1290 | |
1291 | // Parse a function identifier for direct calls. |
1292 | if (isDirect) |
1293 | if (parser.parseAttribute(funcAttr, "callee" , result.attributes)) |
1294 | return failure(); |
1295 | |
1296 | // Parse the function arguments. |
1297 | if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) |
1298 | return failure(); |
1299 | |
1300 | bool isVarArg = parser.parseOptionalKeyword("vararg" ).succeeded(); |
1301 | if (isVarArg) { |
1302 | if (parser.parseLParen().failed() || |
1303 | parser.parseAttribute(calleeType, "callee_type" , result.attributes) |
1304 | .failed() || |
1305 | parser.parseRParen().failed()) |
1306 | return failure(); |
1307 | } |
1308 | |
1309 | if (parser.parseOptionalAttrDict(result.attributes)) |
1310 | return failure(); |
1311 | |
1312 | // Parse the trailing type list and resolve the operands. |
1313 | return parseCallTypeAndResolveOperands(parser, result, isDirect, operands); |
1314 | } |
1315 | |
1316 | LLVMFunctionType CallOp::getCalleeFunctionType() { |
1317 | if (getCalleeType()) |
1318 | return *getCalleeType(); |
1319 | return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands()); |
1320 | } |
1321 | |
1322 | ///===---------------------------------------------------------------------===// |
1323 | /// LLVM::InvokeOp |
1324 | ///===---------------------------------------------------------------------===// |
1325 | |
1326 | void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, |
1327 | ValueRange ops, Block *normal, ValueRange normalOps, |
1328 | Block *unwind, ValueRange unwindOps) { |
1329 | auto calleeType = func.getFunctionType(); |
1330 | build(builder, state, getCallOpResultTypes(calleeType), |
1331 | TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps, |
1332 | unwindOps, nullptr, nullptr, normal, unwind); |
1333 | } |
1334 | |
1335 | void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, |
1336 | FlatSymbolRefAttr callee, ValueRange ops, Block *normal, |
1337 | ValueRange normalOps, Block *unwind, |
1338 | ValueRange unwindOps) { |
1339 | build(builder, state, tys, |
1340 | TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee, |
1341 | ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind); |
1342 | } |
1343 | |
1344 | void InvokeOp::build(OpBuilder &builder, OperationState &state, |
1345 | LLVMFunctionType calleeType, FlatSymbolRefAttr callee, |
1346 | ValueRange ops, Block *normal, ValueRange normalOps, |
1347 | Block *unwind, ValueRange unwindOps) { |
1348 | build(builder, state, getCallOpResultTypes(calleeType), |
1349 | TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr, |
1350 | nullptr, normal, unwind); |
1351 | } |
1352 | |
1353 | SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { |
1354 | assert(index < getNumSuccessors() && "invalid successor index" ); |
1355 | return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() |
1356 | : getUnwindDestOperandsMutable()); |
1357 | } |
1358 | |
1359 | CallInterfaceCallable InvokeOp::getCallableForCallee() { |
1360 | // Direct call. |
1361 | if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) |
1362 | return calleeAttr; |
1363 | // Indirect call, callee Value is the first operand. |
1364 | return getOperand(0); |
1365 | } |
1366 | |
1367 | void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
1368 | // Direct call. |
1369 | if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { |
1370 | auto symRef = callee.get<SymbolRefAttr>(); |
1371 | return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef)); |
1372 | } |
1373 | // Indirect call, callee Value is the first operand. |
1374 | return setOperand(0, callee.get<Value>()); |
1375 | } |
1376 | |
1377 | Operation::operand_range InvokeOp::getArgOperands() { |
1378 | return getOperands().drop_front(getCallee().has_value() ? 0 : 1); |
1379 | } |
1380 | |
1381 | MutableOperandRange InvokeOp::getArgOperandsMutable() { |
1382 | return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1, |
1383 | getCalleeOperands().size()); |
1384 | } |
1385 | |
1386 | LogicalResult InvokeOp::verify() { |
1387 | if (getNumResults() > 1) |
1388 | return emitOpError("must have 0 or 1 result" ); |
1389 | |
1390 | Block *unwindDest = getUnwindDest(); |
1391 | if (unwindDest->empty()) |
1392 | return emitError("must have at least one operation in unwind destination" ); |
1393 | |
1394 | // In unwind destination, first operation must be LandingpadOp |
1395 | if (!isa<LandingpadOp>(unwindDest->front())) |
1396 | return emitError("first operation in unwind destination should be a " |
1397 | "llvm.landingpad operation" ); |
1398 | |
1399 | return success(); |
1400 | } |
1401 | |
1402 | void InvokeOp::print(OpAsmPrinter &p) { |
1403 | auto callee = getCallee(); |
1404 | bool isDirect = callee.has_value(); |
1405 | |
1406 | LLVMFunctionType calleeType; |
1407 | bool isVarArg = false; |
1408 | |
1409 | if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) { |
1410 | calleeType = *optionalCalleeType; |
1411 | isVarArg = calleeType.isVarArg(); |
1412 | } |
1413 | |
1414 | p << ' '; |
1415 | |
1416 | // Print calling convention. |
1417 | if (getCConv() != LLVM::CConv::C) |
1418 | p << stringifyCConv(getCConv()) << ' '; |
1419 | |
1420 | // Either function name or pointer |
1421 | if (isDirect) |
1422 | p.printSymbolName(callee.value()); |
1423 | else |
1424 | p << getOperand(0); |
1425 | |
1426 | p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; |
1427 | p << " to " ; |
1428 | p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); |
1429 | p << " unwind " ; |
1430 | p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); |
1431 | |
1432 | if (isVarArg) |
1433 | p << " vararg(" << calleeType << ")" ; |
1434 | |
1435 | p.printOptionalAttrDict((*this)->getAttrs(), |
1436 | {InvokeOp::getOperandSegmentSizeAttr(), "callee" , |
1437 | "callee_type" , InvokeOp::getCConvAttrName()}); |
1438 | |
1439 | p << " : " ; |
1440 | if (!isDirect) |
1441 | p << getOperand(0).getType() << ", " ; |
1442 | p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), |
1443 | getResultTypes()); |
1444 | } |
1445 | |
1446 | // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use) |
1447 | // `(` ssa-use-list `)` |
1448 | // `to` bb-id (`[` ssa-use-and-type-list `]`)? |
1449 | // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? |
1450 | // ( `vararg(` var-arg-func-type `)` )? |
1451 | // attribute-dict? `:` (type `,`)? function-type |
1452 | ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { |
1453 | SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; |
1454 | SymbolRefAttr funcAttr; |
1455 | TypeAttr calleeType; |
1456 | Block *normalDest, *unwindDest; |
1457 | SmallVector<Value, 4> normalOperands, unwindOperands; |
1458 | Builder &builder = parser.getBuilder(); |
1459 | |
1460 | // Default to C Calling Convention if no keyword is provided. |
1461 | result.addAttribute( |
1462 | getCConvAttrName(result.name), |
1463 | CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( |
1464 | parser, result, LLVM::CConv::C))); |
1465 | |
1466 | // Parse a function pointer for indirect calls. |
1467 | if (parseOptionalCallFuncPtr(parser, operands)) |
1468 | return failure(); |
1469 | bool isDirect = operands.empty(); |
1470 | |
1471 | // Parse a function identifier for direct calls. |
1472 | if (isDirect && parser.parseAttribute(funcAttr, "callee" , result.attributes)) |
1473 | return failure(); |
1474 | |
1475 | // Parse the function arguments. |
1476 | if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || |
1477 | parser.parseKeyword("to" ) || |
1478 | parser.parseSuccessorAndUseList(normalDest, normalOperands) || |
1479 | parser.parseKeyword("unwind" ) || |
1480 | parser.parseSuccessorAndUseList(unwindDest, unwindOperands)) |
1481 | return failure(); |
1482 | |
1483 | bool isVarArg = parser.parseOptionalKeyword("vararg" ).succeeded(); |
1484 | if (isVarArg) { |
1485 | if (parser.parseLParen().failed() || |
1486 | parser.parseAttribute(calleeType, "callee_type" , result.attributes) |
1487 | .failed() || |
1488 | parser.parseRParen().failed()) |
1489 | return failure(); |
1490 | } |
1491 | |
1492 | if (parser.parseOptionalAttrDict(result.attributes)) |
1493 | return failure(); |
1494 | |
1495 | // Parse the trailing type list and resolve the function operands. |
1496 | if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands)) |
1497 | return failure(); |
1498 | |
1499 | result.addSuccessors({normalDest, unwindDest}); |
1500 | result.addOperands(normalOperands); |
1501 | result.addOperands(unwindOperands); |
1502 | |
1503 | result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(), |
1504 | builder.getDenseI32ArrayAttr( |
1505 | {static_cast<int32_t>(operands.size()), |
1506 | static_cast<int32_t>(normalOperands.size()), |
1507 | static_cast<int32_t>(unwindOperands.size())})); |
1508 | return success(); |
1509 | } |
1510 | |
1511 | LLVMFunctionType InvokeOp::getCalleeFunctionType() { |
1512 | if (getCalleeType()) |
1513 | return *getCalleeType(); |
1514 | return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands()); |
1515 | } |
1516 | |
1517 | ///===----------------------------------------------------------------------===// |
1518 | /// Verifying/Printing/Parsing for LLVM::LandingpadOp. |
1519 | ///===----------------------------------------------------------------------===// |
1520 | |
1521 | LogicalResult LandingpadOp::verify() { |
1522 | Value value; |
1523 | if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { |
1524 | if (!func.getPersonality()) |
1525 | return emitError( |
1526 | "llvm.landingpad needs to be in a function with a personality" ); |
1527 | } |
1528 | |
1529 | // Consistency of llvm.landingpad result types is checked in |
1530 | // LLVMFuncOp::verify(). |
1531 | |
1532 | if (!getCleanup() && getOperands().empty()) |
1533 | return emitError("landingpad instruction expects at least one clause or " |
1534 | "cleanup attribute" ); |
1535 | |
1536 | for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { |
1537 | value = getOperand(idx); |
1538 | bool isFilter = llvm::isa<LLVMArrayType>(value.getType()); |
1539 | if (isFilter) { |
1540 | // FIXME: Verify filter clauses when arrays are appropriately handled |
1541 | } else { |
1542 | // catch - global addresses only. |
1543 | // Bitcast ops should have global addresses as their args. |
1544 | if (auto bcOp = value.getDefiningOp<BitcastOp>()) { |
1545 | if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) |
1546 | continue; |
1547 | return emitError("constant clauses expected" ).attachNote(bcOp.getLoc()) |
1548 | << "global addresses expected as operand to " |
1549 | "bitcast used in clauses for landingpad" ; |
1550 | } |
1551 | // ZeroOp and AddressOfOp allowed |
1552 | if (value.getDefiningOp<ZeroOp>()) |
1553 | continue; |
1554 | if (value.getDefiningOp<AddressOfOp>()) |
1555 | continue; |
1556 | return emitError("clause #" ) |
1557 | << idx << " is not a known constant - null, addressof, bitcast" ; |
1558 | } |
1559 | } |
1560 | return success(); |
1561 | } |
1562 | |
1563 | void LandingpadOp::print(OpAsmPrinter &p) { |
1564 | p << (getCleanup() ? " cleanup " : " " ); |
1565 | |
1566 | // Clauses |
1567 | for (auto value : getOperands()) { |
1568 | // Similar to llvm - if clause is an array type then it is filter |
1569 | // clause else catch clause |
1570 | bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType()); |
1571 | p << '(' << (isArrayTy ? "filter " : "catch " ) << value << " : " |
1572 | << value.getType() << ") " ; |
1573 | } |
1574 | |
1575 | p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup" }); |
1576 | |
1577 | p << ": " << getType(); |
1578 | } |
1579 | |
1580 | // <operation> ::= `llvm.landingpad` `cleanup`? |
1581 | // ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? |
1582 | ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { |
1583 | // Check for cleanup |
1584 | if (succeeded(parser.parseOptionalKeyword("cleanup" ))) |
1585 | result.addAttribute("cleanup" , parser.getBuilder().getUnitAttr()); |
1586 | |
1587 | // Parse clauses with types |
1588 | while (succeeded(parser.parseOptionalLParen()) && |
1589 | (succeeded(parser.parseOptionalKeyword("filter" )) || |
1590 | succeeded(parser.parseOptionalKeyword("catch" )))) { |
1591 | OpAsmParser::UnresolvedOperand operand; |
1592 | Type ty; |
1593 | if (parser.parseOperand(operand) || parser.parseColon() || |
1594 | parser.parseType(ty) || |
1595 | parser.resolveOperand(operand, ty, result.operands) || |
1596 | parser.parseRParen()) |
1597 | return failure(); |
1598 | } |
1599 | |
1600 | Type type; |
1601 | if (parser.parseColon() || parser.parseType(type)) |
1602 | return failure(); |
1603 | |
1604 | result.addTypes(type); |
1605 | return success(); |
1606 | } |
1607 | |
1608 | //===----------------------------------------------------------------------===// |
1609 | // ExtractValueOp |
1610 | //===----------------------------------------------------------------------===// |
1611 | |
1612 | /// Extract the type at `position` in the LLVM IR aggregate type |
1613 | /// `containerType`. Each element of `position` is an index into a nested |
1614 | /// aggregate type. Return the resulting type or emit an error. |
1615 | static Type ( |
1616 | function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType, |
1617 | ArrayRef<int64_t> position) { |
1618 | Type llvmType = containerType; |
1619 | if (!isCompatibleType(type: containerType)) { |
1620 | emitError("expected LLVM IR Dialect type, got " ) << containerType; |
1621 | return {}; |
1622 | } |
1623 | |
1624 | // Infer the element type from the structure type: iteratively step inside the |
1625 | // type by taking the element type, indexed by the position attribute for |
1626 | // structures. Check the position index before accessing, it is supposed to |
1627 | // be in bounds. |
1628 | for (int64_t idx : position) { |
1629 | if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) { |
1630 | if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) { |
1631 | emitError("position out of bounds: " ) << idx; |
1632 | return {}; |
1633 | } |
1634 | llvmType = arrayType.getElementType(); |
1635 | } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) { |
1636 | if (idx < 0 || |
1637 | static_cast<unsigned>(idx) >= structType.getBody().size()) { |
1638 | emitError("position out of bounds: " ) << idx; |
1639 | return {}; |
1640 | } |
1641 | llvmType = structType.getBody()[idx]; |
1642 | } else { |
1643 | emitError("expected LLVM IR structure/array type, got: " ) << llvmType; |
1644 | return {}; |
1645 | } |
1646 | } |
1647 | return llvmType; |
1648 | } |
1649 | |
1650 | /// Extract the type at `position` in the wrapped LLVM IR aggregate type |
1651 | /// `containerType`. |
1652 | static Type (Type llvmType, |
1653 | ArrayRef<int64_t> position) { |
1654 | for (int64_t idx : position) { |
1655 | if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) |
1656 | llvmType = structType.getBody()[idx]; |
1657 | else |
1658 | llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType(); |
1659 | } |
1660 | return llvmType; |
1661 | } |
1662 | |
/// Folds llvm.extractvalue by walking back through the chain of
/// llvm.insertvalue ops that produce its container:
///  - If an insertvalue writes exactly the extracted position, fold to the
///    inserted value.
///  - If one position is a prefix of the other, the two overlap, so the walk
///    stops.
///  - Otherwise the insert is disjoint from the extracted position. The op's
///    container operand is rewired in place to the insert's container and the
///    walk continues.
/// Once the op has been modified in place, its own result is returned
/// (presumably signalling an in-place fold to the folding driver). Otherwise
/// a null OpFoldResult is returned.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  // Null until the op is rewired; then set to this op's own result.
  OpFoldResult result = {};
  while (insertValueOp) {
    if (getPosition() == insertValueOp.getPosition())
      return insertValueOp.getValue();
    unsigned min =
        std::min(getPosition().size(), insertValueOp.getPosition().size());
    // If one position is a prefix of the other, stop propagating back, as it
    // would miss dependencies. For instance, %3 should not fold to %f0 in the
    // following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (getPosition().take_front(min) ==
        insertValueOp.getPosition().take_front(min))
      return result;

    // If neither a prefix nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    getContainerMutable().assign(insertValueOp.getContainer());
    result = getResult();
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
1694 | |
1695 | LogicalResult ExtractValueOp::verify() { |
1696 | auto emitError = [this](StringRef msg) { return emitOpError(msg); }; |
1697 | Type valueType = getInsertExtractValueElementType( |
1698 | emitError, getContainer().getType(), getPosition()); |
1699 | if (!valueType) |
1700 | return failure(); |
1701 | |
1702 | if (getRes().getType() != valueType) |
1703 | return emitOpError() << "Type mismatch: extracting from " |
1704 | << getContainer().getType() << " should produce " |
1705 | << valueType << " but this op returns " |
1706 | << getRes().getType(); |
1707 | return success(); |
1708 | } |
1709 | |
1710 | void ExtractValueOp::build(OpBuilder &builder, OperationState &state, |
1711 | Value container, ArrayRef<int64_t> position) { |
1712 | build(builder, state, |
1713 | getInsertExtractValueElementType(container.getType(), position), |
1714 | container, builder.getAttr<DenseI64ArrayAttr>(position)); |
1715 | } |
1716 | |
1717 | //===----------------------------------------------------------------------===// |
1718 | // InsertValueOp |
1719 | //===----------------------------------------------------------------------===// |
1720 | |
1721 | /// Infer the value type from the container type and position. |
1722 | static ParseResult |
1723 | (AsmParser &parser, Type &valueType, |
1724 | Type containerType, |
1725 | DenseI64ArrayAttr position) { |
1726 | valueType = getInsertExtractValueElementType( |
1727 | [&](StringRef msg) { |
1728 | return parser.emitError(loc: parser.getCurrentLocation(), message: msg); |
1729 | }, |
1730 | containerType, position.asArrayRef()); |
1731 | return success(isSuccess: !!valueType); |
1732 | } |
1733 | |
1734 | /// Nothing to print for an inferred type. |
1735 | static void (AsmPrinter &printer, |
1736 | Operation *op, Type valueType, |
1737 | Type containerType, |
1738 | DenseI64ArrayAttr position) {} |
1739 | |
1740 | LogicalResult InsertValueOp::verify() { |
1741 | auto emitError = [this](StringRef msg) { return emitOpError(msg); }; |
1742 | Type valueType = getInsertExtractValueElementType( |
1743 | emitError, getContainer().getType(), getPosition()); |
1744 | if (!valueType) |
1745 | return failure(); |
1746 | |
1747 | if (getValue().getType() != valueType) |
1748 | return emitOpError() << "Type mismatch: cannot insert " |
1749 | << getValue().getType() << " into " |
1750 | << getContainer().getType(); |
1751 | |
1752 | return success(); |
1753 | } |
1754 | |
1755 | //===----------------------------------------------------------------------===// |
1756 | // ReturnOp |
1757 | //===----------------------------------------------------------------------===// |
1758 | |
1759 | LogicalResult ReturnOp::verify() { |
1760 | auto parent = (*this)->getParentOfType<LLVMFuncOp>(); |
1761 | if (!parent) |
1762 | return success(); |
1763 | |
1764 | Type expectedType = parent.getFunctionType().getReturnType(); |
1765 | if (llvm::isa<LLVMVoidType>(expectedType)) { |
1766 | if (!getArg()) |
1767 | return success(); |
1768 | InFlightDiagnostic diag = emitOpError("expected no operands" ); |
1769 | diag.attachNote(parent->getLoc()) << "when returning from function" ; |
1770 | return diag; |
1771 | } |
1772 | if (!getArg()) { |
1773 | if (llvm::isa<LLVMVoidType>(expectedType)) |
1774 | return success(); |
1775 | InFlightDiagnostic diag = emitOpError("expected 1 operand" ); |
1776 | diag.attachNote(parent->getLoc()) << "when returning from function" ; |
1777 | return diag; |
1778 | } |
1779 | if (expectedType != getArg().getType()) { |
1780 | InFlightDiagnostic diag = emitOpError("mismatching result types" ); |
1781 | diag.attachNote(parent->getLoc()) << "when returning from function" ; |
1782 | return diag; |
1783 | } |
1784 | return success(); |
1785 | } |
1786 | |
1787 | //===----------------------------------------------------------------------===// |
1788 | // Verifier for LLVM::AddressOfOp. |
1789 | //===----------------------------------------------------------------------===// |
1790 | |
1791 | static Operation *parentLLVMModule(Operation *op) { |
1792 | Operation *module = op->getParentOp(); |
1793 | while (module && !satisfiesLLVMModule(op: module)) |
1794 | module = module->getParentOp(); |
1795 | assert(module && "unexpected operation outside of a module" ); |
1796 | return module; |
1797 | } |
1798 | |
1799 | GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) { |
1800 | return dyn_cast_or_null<GlobalOp>( |
1801 | symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); |
1802 | } |
1803 | |
1804 | LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) { |
1805 | return dyn_cast_or_null<LLVMFuncOp>( |
1806 | symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); |
1807 | } |
1808 | |
1809 | LogicalResult |
1810 | AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1811 | Operation *symbol = |
1812 | symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()); |
1813 | |
1814 | auto global = dyn_cast_or_null<GlobalOp>(symbol); |
1815 | auto function = dyn_cast_or_null<LLVMFuncOp>(symbol); |
1816 | |
1817 | if (!global && !function) |
1818 | return emitOpError( |
1819 | "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'" ); |
1820 | |
1821 | LLVMPointerType type = getType(); |
1822 | if (global && global.getAddrSpace() != type.getAddressSpace()) |
1823 | return emitOpError("pointer address space must match address space of the " |
1824 | "referenced global" ); |
1825 | |
1826 | return success(); |
1827 | } |
1828 | |
1829 | //===----------------------------------------------------------------------===// |
// Builder and verifier for LLVM::ComdatOp.
1831 | //===----------------------------------------------------------------------===// |
1832 | |
1833 | void ComdatOp::build(OpBuilder &builder, OperationState &result, |
1834 | StringRef symName) { |
1835 | result.addAttribute(getSymNameAttrName(result.name), |
1836 | builder.getStringAttr(symName)); |
1837 | Region *body = result.addRegion(); |
1838 | body->emplaceBlock(); |
1839 | } |
1840 | |
1841 | LogicalResult ComdatOp::verifyRegions() { |
1842 | Region &body = getBody(); |
1843 | for (Operation &op : body.getOps()) |
1844 | if (!isa<ComdatSelectorOp>(op)) |
1845 | return op.emitError( |
1846 | "only comdat selector symbols can appear in a comdat region" ); |
1847 | |
1848 | return success(); |
1849 | } |
1850 | |
1851 | //===----------------------------------------------------------------------===// |
1852 | // Builder, printer and verifier for LLVM::GlobalOp. |
1853 | //===----------------------------------------------------------------------===// |
1854 | |
1855 | void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, |
1856 | bool isConstant, Linkage linkage, StringRef name, |
1857 | Attribute value, uint64_t alignment, unsigned addrSpace, |
1858 | bool dsoLocal, bool threadLocal, SymbolRefAttr comdat, |
1859 | ArrayRef<NamedAttribute> attrs, |
1860 | DIGlobalVariableExpressionAttr dbgExpr) { |
1861 | result.addAttribute(getSymNameAttrName(result.name), |
1862 | builder.getStringAttr(name)); |
1863 | result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); |
1864 | if (isConstant) |
1865 | result.addAttribute(getConstantAttrName(result.name), |
1866 | builder.getUnitAttr()); |
1867 | if (value) |
1868 | result.addAttribute(getValueAttrName(result.name), value); |
1869 | if (dsoLocal) |
1870 | result.addAttribute(getDsoLocalAttrName(result.name), |
1871 | builder.getUnitAttr()); |
1872 | if (threadLocal) |
1873 | result.addAttribute(getThreadLocal_AttrName(result.name), |
1874 | builder.getUnitAttr()); |
1875 | if (comdat) |
1876 | result.addAttribute(getComdatAttrName(result.name), comdat); |
1877 | |
1878 | // Only add an alignment attribute if the "alignment" input |
1879 | // is different from 0. The value must also be a power of two, but |
1880 | // this is tested in GlobalOp::verify, not here. |
1881 | if (alignment != 0) |
1882 | result.addAttribute(getAlignmentAttrName(result.name), |
1883 | builder.getI64IntegerAttr(alignment)); |
1884 | |
1885 | result.addAttribute(getLinkageAttrName(result.name), |
1886 | LinkageAttr::get(builder.getContext(), linkage)); |
1887 | if (addrSpace != 0) |
1888 | result.addAttribute(getAddrSpaceAttrName(result.name), |
1889 | builder.getI32IntegerAttr(addrSpace)); |
1890 | result.attributes.append(attrs.begin(), attrs.end()); |
1891 | |
1892 | if (dbgExpr) |
1893 | result.addAttribute(getDbgExprAttrName(result.name), dbgExpr); |
1894 | |
1895 | result.addRegion(); |
1896 | } |
1897 | |
1898 | void GlobalOp::print(OpAsmPrinter &p) { |
1899 | p << ' ' << stringifyLinkage(getLinkage()) << ' '; |
1900 | StringRef visibility = stringifyVisibility(getVisibility_()); |
1901 | if (!visibility.empty()) |
1902 | p << visibility << ' '; |
1903 | if (getThreadLocal_()) |
1904 | p << "thread_local " ; |
1905 | if (auto unnamedAddr = getUnnamedAddr()) { |
1906 | StringRef str = stringifyUnnamedAddr(*unnamedAddr); |
1907 | if (!str.empty()) |
1908 | p << str << ' '; |
1909 | } |
1910 | if (getConstant()) |
1911 | p << "constant " ; |
1912 | p.printSymbolName(getSymName()); |
1913 | p << '('; |
1914 | if (auto value = getValueOrNull()) |
1915 | p.printAttribute(value); |
1916 | p << ')'; |
1917 | if (auto comdat = getComdat()) |
1918 | p << " comdat(" << *comdat << ')'; |
1919 | |
1920 | // Note that the alignment attribute is printed using the |
1921 | // default syntax here, even though it is an inherent attribute |
1922 | // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) |
1923 | p.printOptionalAttrDict((*this)->getAttrs(), |
1924 | {SymbolTable::getSymbolAttrName(), |
1925 | getGlobalTypeAttrName(), getConstantAttrName(), |
1926 | getValueAttrName(), getLinkageAttrName(), |
1927 | getUnnamedAddrAttrName(), getThreadLocal_AttrName(), |
1928 | getVisibility_AttrName(), getComdatAttrName(), |
1929 | getUnnamedAddrAttrName()}); |
1930 | |
1931 | // Print the trailing type unless it's a string global. |
1932 | if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) |
1933 | return; |
1934 | p << " : " << getType(); |
1935 | |
1936 | Region &initializer = getInitializerRegion(); |
1937 | if (!initializer.empty()) { |
1938 | p << ' '; |
1939 | p.printRegion(initializer, /*printEntryBlockArgs=*/false); |
1940 | } |
1941 | } |
1942 | |
1943 | static LogicalResult verifyComdat(Operation *op, |
1944 | std::optional<SymbolRefAttr> attr) { |
1945 | if (!attr) |
1946 | return success(); |
1947 | |
1948 | auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr); |
1949 | if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector)) |
1950 | return op->emitError() << "expected comdat symbol" ; |
1951 | |
1952 | return success(); |
1953 | } |
1954 | |
1955 | // operation ::= `llvm.mlir.global` linkage? visibility? |
1956 | // (`unnamed_addr` | `local_unnamed_addr`)? |
1957 | // `thread_local`? `constant`? `@` identifier |
1958 | // `(` attribute? `)` (`comdat(` symbol-ref-id `)`)? |
1959 | // attribute-list? (`:` type)? region? |
1960 | // |
1961 | // The type can be omitted for string attributes, in which case it will be |
1962 | // inferred from the value of the string as [strlen(value) x i8]. |
1963 | ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { |
1964 | MLIRContext *ctx = parser.getContext(); |
1965 | // Parse optional linkage, default to External. |
1966 | result.addAttribute(getLinkageAttrName(result.name), |
1967 | LLVM::LinkageAttr::get( |
1968 | ctx, parseOptionalLLVMKeyword<Linkage>( |
1969 | parser, result, LLVM::Linkage::External))); |
1970 | |
1971 | // Parse optional visibility, default to Default. |
1972 | result.addAttribute(getVisibility_AttrName(result.name), |
1973 | parser.getBuilder().getI64IntegerAttr( |
1974 | parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>( |
1975 | parser, result, LLVM::Visibility::Default))); |
1976 | |
1977 | // Parse optional UnnamedAddr, default to None. |
1978 | result.addAttribute(getUnnamedAddrAttrName(result.name), |
1979 | parser.getBuilder().getI64IntegerAttr( |
1980 | parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( |
1981 | parser, result, LLVM::UnnamedAddr::None))); |
1982 | |
1983 | if (succeeded(parser.parseOptionalKeyword("thread_local" ))) |
1984 | result.addAttribute(getThreadLocal_AttrName(result.name), |
1985 | parser.getBuilder().getUnitAttr()); |
1986 | |
1987 | if (succeeded(parser.parseOptionalKeyword("constant" ))) |
1988 | result.addAttribute(getConstantAttrName(result.name), |
1989 | parser.getBuilder().getUnitAttr()); |
1990 | |
1991 | StringAttr name; |
1992 | if (parser.parseSymbolName(name, getSymNameAttrName(result.name), |
1993 | result.attributes) || |
1994 | parser.parseLParen()) |
1995 | return failure(); |
1996 | |
1997 | Attribute value; |
1998 | if (parser.parseOptionalRParen()) { |
1999 | if (parser.parseAttribute(value, getValueAttrName(result.name), |
2000 | result.attributes) || |
2001 | parser.parseRParen()) |
2002 | return failure(); |
2003 | } |
2004 | |
2005 | if (succeeded(parser.parseOptionalKeyword("comdat" ))) { |
2006 | SymbolRefAttr comdat; |
2007 | if (parser.parseLParen() || parser.parseAttribute(comdat) || |
2008 | parser.parseRParen()) |
2009 | return failure(); |
2010 | |
2011 | result.addAttribute(getComdatAttrName(result.name), comdat); |
2012 | } |
2013 | |
2014 | SmallVector<Type, 1> types; |
2015 | if (parser.parseOptionalAttrDict(result.attributes) || |
2016 | parser.parseOptionalColonTypeList(types)) |
2017 | return failure(); |
2018 | |
2019 | if (types.size() > 1) |
2020 | return parser.emitError(parser.getNameLoc(), "expected zero or one type" ); |
2021 | |
2022 | Region &initRegion = *result.addRegion(); |
2023 | if (types.empty()) { |
2024 | if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) { |
2025 | MLIRContext *context = parser.getContext(); |
2026 | auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), |
2027 | strAttr.getValue().size()); |
2028 | types.push_back(arrayType); |
2029 | } else { |
2030 | return parser.emitError(parser.getNameLoc(), |
2031 | "type can only be omitted for string globals" ); |
2032 | } |
2033 | } else { |
2034 | OptionalParseResult parseResult = |
2035 | parser.parseOptionalRegion(initRegion, /*arguments=*/{}, |
2036 | /*argTypes=*/{}); |
2037 | if (parseResult.has_value() && failed(*parseResult)) |
2038 | return failure(); |
2039 | } |
2040 | |
2041 | result.addAttribute(getGlobalTypeAttrName(result.name), |
2042 | TypeAttr::get(types[0])); |
2043 | return success(); |
2044 | } |
2045 | |
2046 | static bool isZeroAttribute(Attribute value) { |
2047 | if (auto intValue = llvm::dyn_cast<IntegerAttr>(value)) |
2048 | return intValue.getValue().isZero(); |
2049 | if (auto fpValue = llvm::dyn_cast<FloatAttr>(value)) |
2050 | return fpValue.getValue().isZero(); |
2051 | if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(Val&: value)) |
2052 | return isZeroAttribute(value: splatValue.getSplatValue<Attribute>()); |
2053 | if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value)) |
2054 | return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); |
2055 | if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value)) |
2056 | return llvm::all_of(arrayValue.getValue(), isZeroAttribute); |
2057 | return false; |
2058 | } |
2059 | |
2060 | LogicalResult GlobalOp::verify() { |
2061 | bool validType = isCompatibleOuterType(getType()) |
2062 | ? !llvm::isa<LLVMVoidType, LLVMTokenType, |
2063 | LLVMMetadataType, LLVMLabelType>(getType()) |
2064 | : llvm::isa<PointerElementTypeInterface>(getType()); |
2065 | if (!validType) |
2066 | return emitOpError( |
2067 | "expects type to be a valid element type for an LLVM global" ); |
2068 | if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) |
2069 | return emitOpError("must appear at the module level" ); |
2070 | |
2071 | if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) { |
2072 | auto type = llvm::dyn_cast<LLVMArrayType>(getType()); |
2073 | IntegerType elementType = |
2074 | type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr; |
2075 | if (!elementType || elementType.getWidth() != 8 || |
2076 | type.getNumElements() != strAttr.getValue().size()) |
2077 | return emitOpError( |
2078 | "requires an i8 array type of the length equal to that of the string " |
2079 | "attribute" ); |
2080 | } |
2081 | |
2082 | if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) { |
2083 | if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal)) |
2084 | return emitOpError() |
2085 | << "this target extension type cannot be used in a global" ; |
2086 | |
2087 | if (Attribute value = getValueOrNull()) |
2088 | return emitOpError() << "global with target extension type can only be " |
2089 | "initialized with zero-initializer" ; |
2090 | } |
2091 | |
2092 | if (getLinkage() == Linkage::Common) { |
2093 | if (Attribute value = getValueOrNull()) { |
2094 | if (!isZeroAttribute(value)) { |
2095 | return emitOpError() |
2096 | << "expected zero value for '" |
2097 | << stringifyLinkage(Linkage::Common) << "' linkage" ; |
2098 | } |
2099 | } |
2100 | } |
2101 | |
2102 | if (getLinkage() == Linkage::Appending) { |
2103 | if (!llvm::isa<LLVMArrayType>(getType())) { |
2104 | return emitOpError() << "expected array type for '" |
2105 | << stringifyLinkage(Linkage::Appending) |
2106 | << "' linkage" ; |
2107 | } |
2108 | } |
2109 | |
2110 | if (failed(verifyComdat(*this, getComdat()))) |
2111 | return failure(); |
2112 | |
2113 | std::optional<uint64_t> alignAttr = getAlignment(); |
2114 | if (alignAttr.has_value()) { |
2115 | uint64_t value = alignAttr.value(); |
2116 | if (!llvm::isPowerOf2_64(value)) |
2117 | return emitError() << "alignment attribute is not a power of 2" ; |
2118 | } |
2119 | |
2120 | return success(); |
2121 | } |
2122 | |
2123 | LogicalResult GlobalOp::verifyRegions() { |
2124 | if (Block *b = getInitializerBlock()) { |
2125 | ReturnOp ret = cast<ReturnOp>(b->getTerminator()); |
2126 | if (ret.operand_type_begin() == ret.operand_type_end()) |
2127 | return emitOpError("initializer region cannot return void" ); |
2128 | if (*ret.operand_type_begin() != getType()) |
2129 | return emitOpError("initializer region type " ) |
2130 | << *ret.operand_type_begin() << " does not match global type " |
2131 | << getType(); |
2132 | |
2133 | for (Operation &op : *b) { |
2134 | auto iface = dyn_cast<MemoryEffectOpInterface>(op); |
2135 | if (!iface || !iface.hasNoEffect()) |
2136 | return op.emitError() |
2137 | << "ops with side effects not allowed in global initializers" ; |
2138 | } |
2139 | |
2140 | if (getValueOrNull()) |
2141 | return emitOpError("cannot have both initializer value and region" ); |
2142 | } |
2143 | |
2144 | return success(); |
2145 | } |
2146 | |
2147 | //===----------------------------------------------------------------------===// |
2148 | // LLVM::GlobalCtorsOp |
2149 | //===----------------------------------------------------------------------===// |
2150 | |
2151 | LogicalResult |
2152 | GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
2153 | for (Attribute ctor : getCtors()) { |
2154 | if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this, |
2155 | symbolTable))) |
2156 | return failure(); |
2157 | } |
2158 | return success(); |
2159 | } |
2160 | |
2161 | LogicalResult GlobalCtorsOp::verify() { |
2162 | if (getCtors().size() != getPriorities().size()) |
2163 | return emitError( |
2164 | "mismatch between the number of ctors and the number of priorities" ); |
2165 | return success(); |
2166 | } |
2167 | |
2168 | //===----------------------------------------------------------------------===// |
2169 | // LLVM::GlobalDtorsOp |
2170 | //===----------------------------------------------------------------------===// |
2171 | |
2172 | LogicalResult |
2173 | GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
2174 | for (Attribute dtor : getDtors()) { |
2175 | if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this, |
2176 | symbolTable))) |
2177 | return failure(); |
2178 | } |
2179 | return success(); |
2180 | } |
2181 | |
2182 | LogicalResult GlobalDtorsOp::verify() { |
2183 | if (getDtors().size() != getPriorities().size()) |
2184 | return emitError( |
2185 | "mismatch between the number of dtors and the number of priorities" ); |
2186 | return success(); |
2187 | } |
2188 | |
2189 | //===----------------------------------------------------------------------===// |
2190 | // ShuffleVectorOp |
2191 | //===----------------------------------------------------------------------===// |
2192 | |
2193 | void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, |
2194 | Value v2, DenseI32ArrayAttr mask, |
2195 | ArrayRef<NamedAttribute> attrs) { |
2196 | auto containerType = v1.getType(); |
2197 | auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), |
2198 | mask.size(), |
2199 | LLVM::isScalableVectorType(containerType)); |
2200 | build(builder, state, vType, v1, v2, mask); |
2201 | state.addAttributes(attrs); |
2202 | } |
2203 | |
2204 | void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, |
2205 | Value v2, ArrayRef<int32_t> mask) { |
2206 | build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask)); |
2207 | } |
2208 | |
2209 | /// Build the result type of a shuffle vector operation. |
2210 | static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, |
2211 | Type &resType, DenseI32ArrayAttr mask) { |
2212 | if (!LLVM::isCompatibleVectorType(type: v1Type)) |
2213 | return parser.emitError(loc: parser.getCurrentLocation(), |
2214 | message: "expected an LLVM compatible vector type" ); |
2215 | resType = LLVM::getVectorType(LLVM::getVectorElementType(type: v1Type), mask.size(), |
2216 | LLVM::isScalableVectorType(vectorType: v1Type)); |
2217 | return success(); |
2218 | } |
2219 | |
2220 | /// Nothing to do when the result type is inferred. |
2221 | static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, |
2222 | Type resType, DenseI32ArrayAttr mask) {} |
2223 | |
2224 | LogicalResult ShuffleVectorOp::verify() { |
2225 | if (LLVM::isScalableVectorType(getV1().getType()) && |
2226 | llvm::any_of(getMask(), [](int32_t v) { return v != 0; })) |
2227 | return emitOpError("expected a splat operation for scalable vectors" ); |
2228 | return success(); |
2229 | } |
2230 | |
2231 | //===----------------------------------------------------------------------===// |
2232 | // Implementations for LLVM::LLVMFuncOp. |
2233 | //===----------------------------------------------------------------------===// |
2234 | |
2235 | // Add the entry block to the function. |
2236 | Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) { |
2237 | assert(empty() && "function already has an entry block" ); |
2238 | OpBuilder::InsertionGuard g(builder); |
2239 | Block *entry = builder.createBlock(&getBody()); |
2240 | |
2241 | // FIXME: Allow passing in proper locations for the entry arguments. |
2242 | LLVMFunctionType type = getFunctionType(); |
2243 | for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) |
2244 | entry->addArgument(type.getParamType(i), getLoc()); |
2245 | return entry; |
2246 | } |
2247 | |
2248 | void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, |
2249 | StringRef name, Type type, LLVM::Linkage linkage, |
2250 | bool dsoLocal, CConv cconv, SymbolRefAttr comdat, |
2251 | ArrayRef<NamedAttribute> attrs, |
2252 | ArrayRef<DictionaryAttr> argAttrs, |
2253 | std::optional<uint64_t> functionEntryCount) { |
2254 | result.addRegion(); |
2255 | result.addAttribute(SymbolTable::getSymbolAttrName(), |
2256 | builder.getStringAttr(name)); |
2257 | result.addAttribute(getFunctionTypeAttrName(result.name), |
2258 | TypeAttr::get(type)); |
2259 | result.addAttribute(getLinkageAttrName(result.name), |
2260 | LinkageAttr::get(builder.getContext(), linkage)); |
2261 | result.addAttribute(getCConvAttrName(result.name), |
2262 | CConvAttr::get(builder.getContext(), cconv)); |
2263 | result.attributes.append(attrs.begin(), attrs.end()); |
2264 | if (dsoLocal) |
2265 | result.addAttribute(getDsoLocalAttrName(result.name), |
2266 | builder.getUnitAttr()); |
2267 | if (comdat) |
2268 | result.addAttribute(getComdatAttrName(result.name), comdat); |
2269 | if (functionEntryCount) |
2270 | result.addAttribute(getFunctionEntryCountAttrName(result.name), |
2271 | builder.getI64IntegerAttr(functionEntryCount.value())); |
2272 | if (argAttrs.empty()) |
2273 | return; |
2274 | |
2275 | assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() && |
2276 | "expected as many argument attribute lists as arguments" ); |
2277 | function_interface_impl::addArgAndResultAttrs( |
2278 | builder, result, argAttrs, /*resultAttrs=*/std::nullopt, |
2279 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
2280 | } |
2281 | |
2282 | // Builds an LLVM function type from the given lists of input and output types. |
2283 | // Returns a null type if any of the types provided are non-LLVM types, or if |
2284 | // there is more than one output type. |
2285 | static Type |
2286 | buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, |
2287 | ArrayRef<Type> outputs, |
2288 | function_interface_impl::VariadicFlag variadicFlag) { |
2289 | Builder &b = parser.getBuilder(); |
2290 | if (outputs.size() > 1) { |
2291 | parser.emitError(loc, message: "failed to construct function type: expected zero or " |
2292 | "one function result" ); |
2293 | return {}; |
2294 | } |
2295 | |
2296 | // Convert inputs to LLVM types, exit early on error. |
2297 | SmallVector<Type, 4> llvmInputs; |
2298 | for (auto t : inputs) { |
2299 | if (!isCompatibleType(type: t)) { |
2300 | parser.emitError(loc, message: "failed to construct function type: expected LLVM " |
2301 | "type for function arguments" ); |
2302 | return {}; |
2303 | } |
2304 | llvmInputs.push_back(Elt: t); |
2305 | } |
2306 | |
2307 | // No output is denoted as "void" in LLVM type system. |
2308 | Type llvmOutput = |
2309 | outputs.empty() ? LLVMVoidType::get(ctx: b.getContext()) : outputs.front(); |
2310 | if (!isCompatibleType(type: llvmOutput)) { |
2311 | parser.emitError(loc, message: "failed to construct function type: expected LLVM " |
2312 | "type for function results" ) |
2313 | << llvmOutput; |
2314 | return {}; |
2315 | } |
2316 | return LLVMFunctionType::get(llvmOutput, llvmInputs, |
2317 | variadicFlag.isVariadic()); |
2318 | } |
2319 | |
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               function-attributes?
//               function-body?
//
// Omitted keywords fall back to defaults: external linkage, default
// visibility, no unnamed_addr and the C calling convention. The body region
// is optional.
//
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      getLinkageAttrName(result.name),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default. Stored as an i64 attribute.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None. Stored as an i64 attribute.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  // Written by parseFunctionSignature below; only read after it succeeds.
  bool isVariadic;

  // Parse `@name(args) -> results`; errors in the resulting type are reported
  // at the start of the signature.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Parse the optional `vscale_range(min, max)`. Both bounds are stored as i32
  // integer attributes. NOTE(review): the values are parsed as int64_t and are
  // not range-checked against i32 here.
  if (succeeded(parser.parseOptionalKeyword("vscale_range" ))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat" ))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  // Remaining attributes come from an optional `attributes {...}` dictionary.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  function_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The body is optional. The region is always added, so a function without a
  // body gets an empty region. Fail only if a body was present but malformed.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
2412 | |
// Print the LLVMFuncOp in the custom form read back by LLVMFuncOp::parse.
// Defaults are elided: 'external' linkage and the C calling convention are
// not printed, and visibility / unnamed_addr are printed only when their
// string form is non-empty. The "void" result is dropped from the signature
// since it cannot be parsed back. Attributes printed inline (or implied by the
// signature) are excluded from the trailing attribute dictionary.
void LLVMFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  if (getLinkage() != LLVM::Linkage::External)
    p << stringifyLinkage(getLinkage()) << ' ';
  StringRef visibility = stringifyVisibility(getVisibility_());
  if (!visibility.empty())
    p << visibility << ' ';
  if (auto unnamedAddr = getUnnamedAddr()) {
    StringRef str = stringifyUnnamedAddr(*unnamedAddr);
    if (!str.empty())
      p << str << ' ';
  }
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  p.printSymbolName(getName());

  // Split the LLVM function type into argument and result lists for the
  // generic function-signature printer.
  LLVMFunctionType fnType = getFunctionType();
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 1> resTypes;
  argTypes.reserve(fnType.getNumParams());
  for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
    argTypes.push_back(fnType.getParamType(i));

  Type returnType = fnType.getReturnType();
  if (!llvm::isa<LLVMVoidType>(returnType))
    resTypes.push_back(returnType);

  function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                  isVarArg(), resTypes);

  // Print vscale range if present. This must precede the comdat, which is the
  // order the parser expects.
  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
      << vscale->getMaxRange().getInt() << ')';

  // Print the optional comdat selector.
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // Print remaining attributes, skipping the ones already printed above.
  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
       getComdatAttrName(), getUnnamedAddrAttrName(),
       getVscaleRangeAttrName()});

  // Print the body if this is not an external function.
  Region &body = getBody();
  if (!body.empty()) {
    p << ' ';
    p.printRegion(body, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/true);
  }
}
2471 | |
2472 | // Verifies LLVM- and implementation-specific properties of the LLVM func Op: |
2473 | // - functions don't have 'common' linkage |
2474 | // - external functions have 'external' or 'extern_weak' linkage; |
2475 | // - vararg is (currently) only supported for external functions; |
2476 | LogicalResult LLVMFuncOp::verify() { |
2477 | if (getLinkage() == LLVM::Linkage::Common) |
2478 | return emitOpError() << "functions cannot have '" |
2479 | << stringifyLinkage(LLVM::Linkage::Common) |
2480 | << "' linkage" ; |
2481 | |
2482 | if (failed(verifyComdat(*this, getComdat()))) |
2483 | return failure(); |
2484 | |
2485 | if (isExternal()) { |
2486 | if (getLinkage() != LLVM::Linkage::External && |
2487 | getLinkage() != LLVM::Linkage::ExternWeak) |
2488 | return emitOpError() << "external functions must have '" |
2489 | << stringifyLinkage(LLVM::Linkage::External) |
2490 | << "' or '" |
2491 | << stringifyLinkage(LLVM::Linkage::ExternWeak) |
2492 | << "' linkage" ; |
2493 | return success(); |
2494 | } |
2495 | |
2496 | Type landingpadResultTy; |
2497 | StringRef diagnosticMessage; |
2498 | bool isLandingpadTypeConsistent = |
2499 | !walk([&](Operation *op) { |
2500 | const auto checkType = [&](Type type, StringRef errorMessage) { |
2501 | if (!landingpadResultTy) { |
2502 | landingpadResultTy = type; |
2503 | return WalkResult::advance(); |
2504 | } |
2505 | if (landingpadResultTy != type) { |
2506 | diagnosticMessage = errorMessage; |
2507 | return WalkResult::interrupt(); |
2508 | } |
2509 | return WalkResult::advance(); |
2510 | }; |
2511 | return TypeSwitch<Operation *, WalkResult>(op) |
2512 | .Case<LandingpadOp>([&](auto landingpad) { |
2513 | constexpr StringLiteral errorMessage = |
2514 | "'llvm.landingpad' should have a consistent result type " |
2515 | "inside a function" ; |
2516 | return checkType(landingpad.getType(), errorMessage); |
2517 | }) |
2518 | .Case<ResumeOp>([&](auto resume) { |
2519 | constexpr StringLiteral errorMessage = |
2520 | "'llvm.resume' should have a consistent input type inside a " |
2521 | "function" ; |
2522 | return checkType(resume.getValue().getType(), errorMessage); |
2523 | }) |
2524 | .Default([](auto) { return WalkResult::skip(); }); |
2525 | }).wasInterrupted(); |
2526 | if (!isLandingpadTypeConsistent) { |
2527 | assert(!diagnosticMessage.empty() && |
2528 | "Expecting a non-empty diagnostic message" ); |
2529 | return emitError(diagnosticMessage); |
2530 | } |
2531 | |
2532 | return success(); |
2533 | } |
2534 | |
2535 | /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: |
2536 | /// - entry block arguments are of LLVM types. |
2537 | LogicalResult LLVMFuncOp::verifyRegions() { |
2538 | if (isExternal()) |
2539 | return success(); |
2540 | |
2541 | unsigned numArguments = getFunctionType().getNumParams(); |
2542 | Block &entryBlock = front(); |
2543 | for (unsigned i = 0; i < numArguments; ++i) { |
2544 | Type argType = entryBlock.getArgument(i).getType(); |
2545 | if (!isCompatibleType(argType)) |
2546 | return emitOpError("entry block argument #" ) |
2547 | << i << " is not of LLVM type" ; |
2548 | } |
2549 | |
2550 | return success(); |
2551 | } |
2552 | |
2553 | Region *LLVMFuncOp::getCallableRegion() { |
2554 | if (isExternal()) |
2555 | return nullptr; |
2556 | return &getBody(); |
2557 | } |
2558 | |
2559 | //===----------------------------------------------------------------------===// |
2560 | // ZeroOp. |
2561 | //===----------------------------------------------------------------------===// |
2562 | |
2563 | LogicalResult LLVM::ZeroOp::verify() { |
2564 | if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) |
2565 | if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit)) |
2566 | return emitOpError() |
2567 | << "target extension type does not support zero-initializer" ; |
2568 | |
2569 | return success(); |
2570 | } |
2571 | |
2572 | //===----------------------------------------------------------------------===// |
2573 | // ConstantOp. |
2574 | //===----------------------------------------------------------------------===// |
2575 | |
2576 | LogicalResult LLVM::ConstantOp::verify() { |
2577 | if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) { |
2578 | auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType()); |
2579 | if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || |
2580 | !arrayType.getElementType().isInteger(8)) { |
2581 | return emitOpError() << "expected array type of " |
2582 | << sAttr.getValue().size() |
2583 | << " i8 elements for the string constant" ; |
2584 | } |
2585 | return success(); |
2586 | } |
2587 | if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) { |
2588 | if (structType.getBody().size() != 2 || |
2589 | structType.getBody()[0] != structType.getBody()[1]) { |
2590 | return emitError() << "expected struct type with two elements of the " |
2591 | "same type, the type of a complex constant" ; |
2592 | } |
2593 | |
2594 | auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue()); |
2595 | if (!arrayAttr || arrayAttr.size() != 2) { |
2596 | return emitOpError() << "expected array attribute with two elements, " |
2597 | "representing a complex constant" ; |
2598 | } |
2599 | auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]); |
2600 | auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]); |
2601 | if (!re || !im || re.getType() != im.getType()) { |
2602 | return emitOpError() |
2603 | << "expected array attribute with two elements of the same type" ; |
2604 | } |
2605 | |
2606 | Type elementType = structType.getBody()[0]; |
2607 | if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>( |
2608 | elementType)) { |
2609 | return emitError() |
2610 | << "expected struct element types to be floating point type or " |
2611 | "integer type" ; |
2612 | } |
2613 | return success(); |
2614 | } |
2615 | if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) { |
2616 | return emitOpError() << "does not support target extension type." ; |
2617 | } |
2618 | if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue())) |
2619 | return emitOpError() |
2620 | << "only supports integer, float, string or elements attributes" ; |
2621 | if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) { |
2622 | if (!llvm::isa<IntegerType>(getType())) |
2623 | return emitOpError() << "expected integer type" ; |
2624 | } |
2625 | if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) { |
2626 | const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); |
2627 | unsigned floatWidth = APFloat::getSizeInBits(sem); |
2628 | if (auto floatTy = dyn_cast<FloatType>(getType())) { |
2629 | if (floatTy.getWidth() != floatWidth) { |
2630 | return emitOpError() << "expected float type of width " << floatWidth; |
2631 | } |
2632 | } |
2633 | // See the comment for getLLVMConstant for more details about why 8-bit |
2634 | // floats can be represented by integers. |
2635 | if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) { |
2636 | return emitOpError() << "expected integer type of width " << floatWidth; |
2637 | } |
2638 | } |
2639 | if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) { |
2640 | if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) && |
2641 | !isa<LLVM::LLVMFixedVectorType>(getType()) && |
2642 | !isa<LLVM::LLVMScalableVectorType>(getType())) |
2643 | return emitOpError() << "expected vector or array type" ; |
2644 | } |
2645 | return success(); |
2646 | } |
2647 | |
2648 | bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) { |
2649 | // The value's type must be the same as the provided type. |
2650 | auto typedAttr = dyn_cast<TypedAttr>(value); |
2651 | if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type)) |
2652 | return false; |
2653 | // The value's type must be an LLVM compatible type. |
2654 | if (!isCompatibleType(type)) |
2655 | return false; |
2656 | // TODO: Add support for additional attributes kinds once needed. |
2657 | return isa<IntegerAttr, FloatAttr, ElementsAttr>(value); |
2658 | } |
2659 | |
2660 | ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value, |
2661 | Type type, Location loc) { |
2662 | if (isBuildableWith(value, type)) |
2663 | return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value)); |
2664 | return nullptr; |
2665 | } |
2666 | |
2667 | // Constant op constant-folds to its value. |
2668 | OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); } |
2669 | |
2670 | //===----------------------------------------------------------------------===// |
2671 | // AtomicRMWOp |
2672 | //===----------------------------------------------------------------------===// |
2673 | |
2674 | void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, |
2675 | AtomicBinOp binOp, Value ptr, Value val, |
2676 | AtomicOrdering ordering, StringRef syncscope, |
2677 | unsigned alignment, bool isVolatile) { |
2678 | build(builder, state, val.getType(), binOp, ptr, val, ordering, |
2679 | !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr, |
2680 | alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, |
2681 | /*access_groups=*/nullptr, |
2682 | /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
2683 | } |
2684 | |
2685 | LogicalResult AtomicRMWOp::verify() { |
2686 | auto valType = getVal().getType(); |
2687 | if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub || |
2688 | getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) { |
2689 | if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) |
2690 | return emitOpError("expected LLVM IR floating point type" ); |
2691 | } else if (getBinOp() == AtomicBinOp::xchg) { |
2692 | if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/true)) |
2693 | return emitOpError("unexpected LLVM IR type for 'xchg' bin_op" ); |
2694 | } else { |
2695 | auto intType = llvm::dyn_cast<IntegerType>(valType); |
2696 | unsigned intBitWidth = intType ? intType.getWidth() : 0; |
2697 | if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && |
2698 | intBitWidth != 64) |
2699 | return emitOpError("expected LLVM IR integer type" ); |
2700 | } |
2701 | |
2702 | if (static_cast<unsigned>(getOrdering()) < |
2703 | static_cast<unsigned>(AtomicOrdering::monotonic)) |
2704 | return emitOpError() << "expected at least '" |
2705 | << stringifyAtomicOrdering(AtomicOrdering::monotonic) |
2706 | << "' ordering" ; |
2707 | |
2708 | return success(); |
2709 | } |
2710 | |
2711 | //===----------------------------------------------------------------------===// |
2712 | // AtomicCmpXchgOp |
2713 | //===----------------------------------------------------------------------===// |
2714 | |
2715 | /// Returns an LLVM struct type that contains a value type and a boolean type. |
2716 | static LLVMStructType getValAndBoolStructType(Type valType) { |
2717 | auto boolType = IntegerType::get(valType.getContext(), 1); |
2718 | return LLVMStructType::getLiteral(context: valType.getContext(), types: {valType, boolType}); |
2719 | } |
2720 | |
2721 | void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state, |
2722 | Value ptr, Value cmp, Value val, |
2723 | AtomicOrdering successOrdering, |
2724 | AtomicOrdering failureOrdering, StringRef syncscope, |
2725 | unsigned alignment, bool isWeak, bool isVolatile) { |
2726 | build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val, |
2727 | successOrdering, failureOrdering, |
2728 | !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr, |
2729 | alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak, |
2730 | isVolatile, /*access_groups=*/nullptr, |
2731 | /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
2732 | } |
2733 | |
2734 | LogicalResult AtomicCmpXchgOp::verify() { |
2735 | auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()); |
2736 | if (!ptrType) |
2737 | return emitOpError("expected LLVM IR pointer type for operand #0" ); |
2738 | auto valType = getVal().getType(); |
2739 | if (!isTypeCompatibleWithAtomicOp(valType, |
2740 | /*isPointerTypeAllowed=*/true)) |
2741 | return emitOpError("unexpected LLVM IR type" ); |
2742 | if (getSuccessOrdering() < AtomicOrdering::monotonic || |
2743 | getFailureOrdering() < AtomicOrdering::monotonic) |
2744 | return emitOpError("ordering must be at least 'monotonic'" ); |
2745 | if (getFailureOrdering() == AtomicOrdering::release || |
2746 | getFailureOrdering() == AtomicOrdering::acq_rel) |
2747 | return emitOpError("failure ordering cannot be 'release' or 'acq_rel'" ); |
2748 | return success(); |
2749 | } |
2750 | |
2751 | //===----------------------------------------------------------------------===// |
2752 | // FenceOp |
2753 | //===----------------------------------------------------------------------===// |
2754 | |
2755 | void FenceOp::build(OpBuilder &builder, OperationState &state, |
2756 | AtomicOrdering ordering, StringRef syncscope) { |
2757 | build(builder, state, ordering, |
2758 | syncscope.empty() ? nullptr : builder.getStringAttr(syncscope)); |
2759 | } |
2760 | |
2761 | LogicalResult FenceOp::verify() { |
2762 | if (getOrdering() == AtomicOrdering::not_atomic || |
2763 | getOrdering() == AtomicOrdering::unordered || |
2764 | getOrdering() == AtomicOrdering::monotonic) |
2765 | return emitOpError("can be given only acquire, release, acq_rel, " |
2766 | "and seq_cst orderings" ); |
2767 | return success(); |
2768 | } |
2769 | |
2770 | //===----------------------------------------------------------------------===// |
2771 | // Verifier for extension ops |
2772 | //===----------------------------------------------------------------------===// |
2773 | |
2774 | /// Verifies that the given extension operation operates on consistent scalars |
2775 | /// or vectors, and that the target width is larger than the input width. |
2776 | template <class ExtOp> |
2777 | static LogicalResult verifyExtOp(ExtOp op) { |
2778 | IntegerType inputType, outputType; |
2779 | if (isCompatibleVectorType(op.getArg().getType())) { |
2780 | if (!isCompatibleVectorType(op.getResult().getType())) |
2781 | return op.emitError( |
2782 | "input type is a vector but output type is an integer" ); |
2783 | if (getVectorNumElements(op.getArg().getType()) != |
2784 | getVectorNumElements(op.getResult().getType())) |
2785 | return op.emitError("input and output vectors are of incompatible shape" ); |
2786 | // Because this is a CastOp, the element of vectors is guaranteed to be an |
2787 | // integer. |
2788 | inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType())); |
2789 | outputType = |
2790 | cast<IntegerType>(getVectorElementType(op.getResult().getType())); |
2791 | } else { |
2792 | // Because this is a CastOp and arg is not a vector, arg is guaranteed to be |
2793 | // an integer. |
2794 | inputType = cast<IntegerType>(op.getArg().getType()); |
2795 | outputType = dyn_cast<IntegerType>(op.getResult().getType()); |
2796 | if (!outputType) |
2797 | return op.emitError( |
2798 | "input type is an integer but output type is a vector" ); |
2799 | } |
2800 | |
2801 | if (outputType.getWidth() <= inputType.getWidth()) |
2802 | return op.emitError("integer width of the output type is smaller or " |
2803 | "equal to the integer width of the input type" ); |
2804 | return success(); |
2805 | } |
2806 | |
2807 | //===----------------------------------------------------------------------===// |
2808 | // ZExtOp |
2809 | //===----------------------------------------------------------------------===// |
2810 | |
2811 | LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); } |
2812 | |
2813 | OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) { |
2814 | auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg()); |
2815 | if (!arg) |
2816 | return {}; |
2817 | |
2818 | size_t targetSize = cast<IntegerType>(getType()).getWidth(); |
2819 | return IntegerAttr::get(getType(), arg.getValue().zext(targetSize)); |
2820 | } |
2821 | |
2822 | //===----------------------------------------------------------------------===// |
2823 | // SExtOp |
2824 | //===----------------------------------------------------------------------===// |
2825 | |
2826 | LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); } |
2827 | |
2828 | //===----------------------------------------------------------------------===// |
2829 | // Folder and verifier for LLVM::BitcastOp |
2830 | //===----------------------------------------------------------------------===// |
2831 | |
2832 | /// Folds a cast op that can be chained. |
2833 | template <typename T> |
2834 | static OpFoldResult foldChainableCast(T castOp, |
2835 | typename T::FoldAdaptor adaptor) { |
2836 | // cast(x : T0, T0) -> x |
2837 | if (castOp.getArg().getType() == castOp.getType()) |
2838 | return castOp.getArg(); |
2839 | if (auto prev = castOp.getArg().template getDefiningOp<T>()) { |
2840 | // cast(cast(x : T0, T1), T0) -> x |
2841 | if (prev.getArg().getType() == castOp.getType()) |
2842 | return prev.getArg(); |
2843 | // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2) |
2844 | castOp.getArgMutable().set(prev.getArg()); |
2845 | return Value{castOp}; |
2846 | } |
2847 | return {}; |
2848 | } |
2849 | |
2850 | OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) { |
2851 | return foldChainableCast(*this, adaptor); |
2852 | } |
2853 | |
2854 | LogicalResult LLVM::BitcastOp::verify() { |
2855 | auto resultType = llvm::dyn_cast<LLVMPointerType>( |
2856 | extractVectorElementType(getResult().getType())); |
2857 | auto sourceType = llvm::dyn_cast<LLVMPointerType>( |
2858 | extractVectorElementType(getArg().getType())); |
2859 | |
2860 | // If one of the types is a pointer (or vector of pointers), then |
2861 | // both source and result type have to be pointers. |
2862 | if (static_cast<bool>(resultType) != static_cast<bool>(sourceType)) |
2863 | return emitOpError("can only cast pointers from and to pointers" ); |
2864 | |
2865 | if (!resultType) |
2866 | return success(); |
2867 | |
2868 | auto isVector = |
2869 | llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>; |
2870 | |
2871 | // Due to bitcast requiring both operands to be of the same size, it is not |
2872 | // possible for only one of the two to be a pointer of vectors. |
2873 | if (isVector(getResult().getType()) && !isVector(getArg().getType())) |
2874 | return emitOpError("cannot cast pointer to vector of pointers" ); |
2875 | |
2876 | if (!isVector(getResult().getType()) && isVector(getArg().getType())) |
2877 | return emitOpError("cannot cast vector of pointers to pointer" ); |
2878 | |
2879 | // Bitcast cannot cast between pointers of different address spaces. |
2880 | // 'llvm.addrspacecast' must be used for this purpose instead. |
2881 | if (resultType.getAddressSpace() != sourceType.getAddressSpace()) |
2882 | return emitOpError("cannot cast pointers of different address spaces, " |
2883 | "use 'llvm.addrspacecast' instead" ); |
2884 | |
2885 | return success(); |
2886 | } |
2887 | |
2888 | //===----------------------------------------------------------------------===// |
2889 | // Folder for LLVM::AddrSpaceCastOp |
2890 | //===----------------------------------------------------------------------===// |
2891 | |
2892 | OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) { |
2893 | return foldChainableCast(*this, adaptor); |
2894 | } |
2895 | |
2896 | //===----------------------------------------------------------------------===// |
2897 | // Folder for LLVM::GEPOp |
2898 | //===----------------------------------------------------------------------===// |
2899 | |
2900 | OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) { |
2901 | GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(), |
2902 | adaptor.getDynamicIndices()); |
2903 | |
2904 | // gep %x:T, 0 -> %x |
2905 | if (getBase().getType() == getType() && indices.size() == 1) |
2906 | if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0])) |
2907 | if (integer.getValue().isZero()) |
2908 | return getBase(); |
2909 | |
2910 | // Canonicalize any dynamic indices of constant value to constant indices. |
2911 | bool changed = false; |
2912 | SmallVector<GEPArg> gepArgs; |
2913 | for (auto iter : llvm::enumerate(indices)) { |
2914 | auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value()); |
2915 | // Constant indices can only be int32_t, so if integer does not fit we |
2916 | // are forced to keep it dynamic, despite being a constant. |
2917 | if (!indices.isDynamicIndex(iter.index()) || !integer || |
2918 | !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) { |
2919 | |
2920 | PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()]; |
2921 | if (Value val = llvm::dyn_cast_if_present<Value>(existing)) |
2922 | gepArgs.emplace_back(val); |
2923 | else |
2924 | gepArgs.emplace_back(existing.get<IntegerAttr>().getInt()); |
2925 | |
2926 | continue; |
2927 | } |
2928 | |
2929 | changed = true; |
2930 | gepArgs.emplace_back(integer.getInt()); |
2931 | } |
2932 | if (changed) { |
2933 | SmallVector<int32_t> rawConstantIndices; |
2934 | SmallVector<Value> dynamicIndices; |
2935 | destructureIndices(getElemType(), gepArgs, rawConstantIndices, |
2936 | dynamicIndices); |
2937 | |
2938 | getDynamicIndicesMutable().assign(dynamicIndices); |
2939 | setRawConstantIndices(rawConstantIndices); |
2940 | return Value{*this}; |
2941 | } |
2942 | |
2943 | return {}; |
2944 | } |
2945 | |
2946 | //===----------------------------------------------------------------------===// |
2947 | // ShlOp |
2948 | //===----------------------------------------------------------------------===// |
2949 | |
2950 | OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) { |
2951 | auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()); |
2952 | if (!rhs) |
2953 | return {}; |
2954 | |
2955 | if (rhs.getValue().getZExtValue() >= |
2956 | getLhs().getType().getIntOrFloatBitWidth()) |
2957 | return {}; // TODO: Fold into poison. |
2958 | |
2959 | auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs()); |
2960 | if (!lhs) |
2961 | return {}; |
2962 | |
2963 | return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue())); |
2964 | } |
2965 | |
2966 | //===----------------------------------------------------------------------===// |
2967 | // OrOp |
2968 | //===----------------------------------------------------------------------===// |
2969 | |
2970 | OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) { |
2971 | auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs()); |
2972 | if (!lhs) |
2973 | return {}; |
2974 | |
2975 | auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()); |
2976 | if (!rhs) |
2977 | return {}; |
2978 | |
2979 | return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue()); |
2980 | } |
2981 | |
2982 | //===----------------------------------------------------------------------===// |
2983 | // CallIntrinsicOp |
2984 | //===----------------------------------------------------------------------===// |
2985 | |
2986 | LogicalResult CallIntrinsicOp::verify() { |
2987 | if (!getIntrin().starts_with("llvm." )) |
2988 | return emitOpError() << "intrinsic name must start with 'llvm.'" ; |
2989 | return success(); |
2990 | } |
2991 | |
2992 | //===----------------------------------------------------------------------===// |
2993 | // OpAsmDialectInterface |
2994 | //===----------------------------------------------------------------------===// |
2995 | |
2996 | namespace { |
2997 | struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface { |
2998 | using OpAsmDialectInterface::OpAsmDialectInterface; |
2999 | |
3000 | AliasResult getAlias(Attribute attr, raw_ostream &os) const override { |
3001 | return TypeSwitch<Attribute, AliasResult>(attr) |
3002 | .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr, |
3003 | DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr, |
3004 | DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr, |
3005 | DIGlobalVariableExpressionAttr, DILabelAttr, DILexicalBlockAttr, |
3006 | DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr, |
3007 | DINamespaceAttr, DINullTypeAttr, DISubprogramAttr, |
3008 | DISubroutineTypeAttr, LoopAnnotationAttr, LoopVectorizeAttr, |
3009 | LoopInterleaveAttr, LoopUnrollAttr, LoopUnrollAndJamAttr, |
3010 | LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr, |
3011 | LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, TBAATagAttr, |
3012 | TBAATypeDescriptorAttr>([&](auto attr) { |
3013 | os << decltype(attr)::getMnemonic(); |
3014 | return AliasResult::OverridableAlias; |
3015 | }) |
3016 | .Default([](Attribute) { return AliasResult::NoAlias; }); |
3017 | } |
3018 | }; |
3019 | } // namespace |
3020 | |
3021 | //===----------------------------------------------------------------------===// |
3022 | // LinkerOptionsOp |
3023 | //===----------------------------------------------------------------------===// |
3024 | |
3025 | LogicalResult LinkerOptionsOp::verify() { |
3026 | if (mlir::Operation *parentOp = (*this)->getParentOp(); |
3027 | parentOp && !satisfiesLLVMModule(parentOp)) |
3028 | return emitOpError("must appear at the module level" ); |
3029 | return success(); |
3030 | } |
3031 | |
3032 | //===----------------------------------------------------------------------===// |
3033 | // LLVMDialect initialization, type parsing, and registration. |
3034 | //===----------------------------------------------------------------------===// |
3035 | |
/// Registers the LLVM dialect's attributes, types, operations, and dialect
/// interfaces.
void LLVMDialect::initialize() {
  registerAttributes();

  // The types listed here are added explicitly; the other dialect types are
  // registered through registerTypes().
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMStructType>();
  // clang-format on
  registerTypes();

  // Register both the core LLVM ops and the LLVM intrinsic ops, using the op
  // lists expanded from LLVMOps.cpp.inc and LLVMIntrinsicOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      ,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
  // Attribute aliases for the printer (LLVMOpAsmDialectInterface in this file).
  // clang-format off
  addInterfaces<LLVMOpAsmDialectInterface>();
  // clang-format on
  // Inliner interface, attached by the helper declared in LLVMInlining.h.
  detail::addLLVMInlinerInterface(this);
}
3065 | |
3066 | #define GET_OP_CLASSES |
3067 | #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" |
3068 | |
3069 | #define GET_OP_CLASSES |
3070 | #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" |
3071 | |
3072 | LogicalResult LLVMDialect::verifyDataLayoutString( |
3073 | StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { |
3074 | llvm::Expected<llvm::DataLayout> maybeDataLayout = |
3075 | llvm::DataLayout::parse(descr); |
3076 | if (maybeDataLayout) |
3077 | return success(); |
3078 | |
3079 | std::string message; |
3080 | llvm::raw_string_ostream messageStream(message); |
3081 | llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); |
3082 | reportError("invalid data layout descriptor: " + messageStream.str()); |
3083 | return failure(); |
3084 | } |
3085 | |
3086 | /// Verify LLVM dialect attributes. |
3087 | LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, |
3088 | NamedAttribute attr) { |
3089 | // If the data layout attribute is present, it must use the LLVM data layout |
3090 | // syntax. Try parsing it and report errors in case of failure. Users of this |
3091 | // attribute may assume it is well-formed and can pass it to the (asserting) |
3092 | // llvm::DataLayout constructor. |
3093 | if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) |
3094 | return success(); |
3095 | if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue())) |
3096 | return verifyDataLayoutString( |
3097 | stringAttr.getValue(), |
3098 | [op](const Twine &message) { op->emitOpError() << message.str(); }); |
3099 | |
3100 | return op->emitOpError() << "expected '" |
3101 | << LLVM::LLVMDialect::getDataLayoutAttrName() |
3102 | << "' to be a string attributes" ; |
3103 | } |
3104 | |
3105 | LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op, |
3106 | Type paramType, |
3107 | NamedAttribute paramAttr) { |
3108 | // LLVM attribute may be attached to a result of operation that has not been |
3109 | // converted to LLVM dialect yet, so the result may have a type with unknown |
3110 | // representation in LLVM dialect type space. In this case we cannot verify |
3111 | // whether the attribute may be |
3112 | bool verifyValueType = isCompatibleType(paramType); |
3113 | StringAttr name = paramAttr.getName(); |
3114 | |
3115 | auto checkUnitAttrType = [&]() -> LogicalResult { |
3116 | if (!llvm::isa<UnitAttr>(paramAttr.getValue())) |
3117 | return op->emitError() << name << " should be a unit attribute" ; |
3118 | return success(); |
3119 | }; |
3120 | auto checkTypeAttrType = [&]() -> LogicalResult { |
3121 | if (!llvm::isa<TypeAttr>(paramAttr.getValue())) |
3122 | return op->emitError() << name << " should be a type attribute" ; |
3123 | return success(); |
3124 | }; |
3125 | auto checkIntegerAttrType = [&]() -> LogicalResult { |
3126 | if (!llvm::isa<IntegerAttr>(paramAttr.getValue())) |
3127 | return op->emitError() << name << " should be an integer attribute" ; |
3128 | return success(); |
3129 | }; |
3130 | auto checkPointerType = [&]() -> LogicalResult { |
3131 | if (!llvm::isa<LLVMPointerType>(paramType)) |
3132 | return op->emitError() |
3133 | << name << " attribute attached to non-pointer LLVM type" ; |
3134 | return success(); |
3135 | }; |
3136 | auto checkIntegerType = [&]() -> LogicalResult { |
3137 | if (!llvm::isa<IntegerType>(paramType)) |
3138 | return op->emitError() |
3139 | << name << " attribute attached to non-integer LLVM type" ; |
3140 | return success(); |
3141 | }; |
3142 | auto checkPointerTypeMatches = [&]() -> LogicalResult { |
3143 | if (failed(checkPointerType())) |
3144 | return failure(); |
3145 | |
3146 | return success(); |
3147 | }; |
3148 | |
3149 | // Check a unit attribute that is attached to a pointer value. |
3150 | if (name == LLVMDialect::getNoAliasAttrName() || |
3151 | name == LLVMDialect::getReadonlyAttrName() || |
3152 | name == LLVMDialect::getReadnoneAttrName() || |
3153 | name == LLVMDialect::getWriteOnlyAttrName() || |
3154 | name == LLVMDialect::getNestAttrName() || |
3155 | name == LLVMDialect::getNoCaptureAttrName() || |
3156 | name == LLVMDialect::getNoFreeAttrName() || |
3157 | name == LLVMDialect::getNonNullAttrName()) { |
3158 | if (failed(checkUnitAttrType())) |
3159 | return failure(); |
3160 | if (verifyValueType && failed(checkPointerType())) |
3161 | return failure(); |
3162 | return success(); |
3163 | } |
3164 | |
3165 | // Check a type attribute that is attached to a pointer value. |
3166 | if (name == LLVMDialect::getStructRetAttrName() || |
3167 | name == LLVMDialect::getByValAttrName() || |
3168 | name == LLVMDialect::getByRefAttrName() || |
3169 | name == LLVMDialect::getInAllocaAttrName() || |
3170 | name == LLVMDialect::getPreallocatedAttrName()) { |
3171 | if (failed(checkTypeAttrType())) |
3172 | return failure(); |
3173 | if (verifyValueType && failed(checkPointerTypeMatches())) |
3174 | return failure(); |
3175 | return success(); |
3176 | } |
3177 | |
3178 | // Check a unit attribute that is attached to an integer value. |
3179 | if (name == LLVMDialect::getSExtAttrName() || |
3180 | name == LLVMDialect::getZExtAttrName()) { |
3181 | if (failed(checkUnitAttrType())) |
3182 | return failure(); |
3183 | if (verifyValueType && failed(checkIntegerType())) |
3184 | return failure(); |
3185 | return success(); |
3186 | } |
3187 | |
3188 | // Check an integer attribute that is attached to a pointer value. |
3189 | if (name == LLVMDialect::getAlignAttrName() || |
3190 | name == LLVMDialect::getDereferenceableAttrName() || |
3191 | name == LLVMDialect::getDereferenceableOrNullAttrName() || |
3192 | name == LLVMDialect::getStackAlignmentAttrName()) { |
3193 | if (failed(checkIntegerAttrType())) |
3194 | return failure(); |
3195 | if (verifyValueType && failed(checkPointerType())) |
3196 | return failure(); |
3197 | return success(); |
3198 | } |
3199 | |
3200 | // Check a unit attribute that can be attached to arbitrary types. |
3201 | if (name == LLVMDialect::getNoUndefAttrName() || |
3202 | name == LLVMDialect::getInRegAttrName() || |
3203 | name == LLVMDialect::getReturnedAttrName()) |
3204 | return checkUnitAttrType(); |
3205 | |
3206 | return success(); |
3207 | } |
3208 | |
3209 | /// Verify LLVMIR function argument attributes. |
3210 | LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, |
3211 | unsigned regionIdx, |
3212 | unsigned argIdx, |
3213 | NamedAttribute argAttr) { |
3214 | auto funcOp = dyn_cast<FunctionOpInterface>(op); |
3215 | if (!funcOp) |
3216 | return success(); |
3217 | Type argType = funcOp.getArgumentTypes()[argIdx]; |
3218 | |
3219 | return verifyParameterAttribute(op, argType, argAttr); |
3220 | } |
3221 | |
3222 | LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, |
3223 | unsigned regionIdx, |
3224 | unsigned resIdx, |
3225 | NamedAttribute resAttr) { |
3226 | auto funcOp = dyn_cast<FunctionOpInterface>(op); |
3227 | if (!funcOp) |
3228 | return success(); |
3229 | Type resType = funcOp.getResultTypes()[resIdx]; |
3230 | |
3231 | // Check to see if this function has a void return with a result attribute |
3232 | // to it. It isn't clear what semantics we would assign to that. |
3233 | if (llvm::isa<LLVMVoidType>(resType)) |
3234 | return op->emitError() << "cannot attach result attributes to functions " |
3235 | "with a void return" ; |
3236 | |
3237 | // Check to see if this attribute is allowed as a result attribute. Only |
3238 | // explicitly forbidden LLVM attributes will cause an error. |
3239 | auto name = resAttr.getName(); |
3240 | if (name == LLVMDialect::getAllocAlignAttrName() || |
3241 | name == LLVMDialect::getAllocatedPointerAttrName() || |
3242 | name == LLVMDialect::getByValAttrName() || |
3243 | name == LLVMDialect::getByRefAttrName() || |
3244 | name == LLVMDialect::getInAllocaAttrName() || |
3245 | name == LLVMDialect::getNestAttrName() || |
3246 | name == LLVMDialect::getNoCaptureAttrName() || |
3247 | name == LLVMDialect::getNoFreeAttrName() || |
3248 | name == LLVMDialect::getPreallocatedAttrName() || |
3249 | name == LLVMDialect::getReadnoneAttrName() || |
3250 | name == LLVMDialect::getReadonlyAttrName() || |
3251 | name == LLVMDialect::getReturnedAttrName() || |
3252 | name == LLVMDialect::getStackAlignmentAttrName() || |
3253 | name == LLVMDialect::getStructRetAttrName() || |
3254 | name == LLVMDialect::getWriteOnlyAttrName()) |
3255 | return op->emitError() << name << " is not a valid result attribute" ; |
3256 | return verifyParameterAttribute(op, resType, resAttr); |
3257 | } |
3258 | |
3259 | Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, |
3260 | Type type, Location loc) { |
3261 | return LLVM::ConstantOp::materialize(builder, value, type, loc); |
3262 | } |
3263 | |
3264 | //===----------------------------------------------------------------------===// |
3265 | // Utility functions. |
3266 | //===----------------------------------------------------------------------===// |
3267 | |
3268 | Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, |
3269 | StringRef name, StringRef value, |
3270 | LLVM::Linkage linkage) { |
3271 | assert(builder.getInsertionBlock() && |
3272 | builder.getInsertionBlock()->getParentOp() && |
3273 | "expected builder to point to a block constrained in an op" ); |
3274 | auto module = |
3275 | builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); |
3276 | assert(module && "builder points to an op outside of a module" ); |
3277 | |
3278 | // Create the global at the entry of the module. |
3279 | OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); |
3280 | MLIRContext *ctx = builder.getContext(); |
3281 | auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); |
3282 | auto global = moduleBuilder.create<LLVM::GlobalOp>( |
3283 | loc, type, /*isConstant=*/true, linkage, name, |
3284 | builder.getStringAttr(value), /*alignment=*/0); |
3285 | |
3286 | LLVMPointerType ptrType = LLVMPointerType::get(ctx); |
3287 | // Get the pointer to the first character in the global string. |
3288 | Value globalPtr = |
3289 | builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr()); |
3290 | return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr, |
3291 | ArrayRef<GEPArg>{0, 0}); |
3292 | } |
3293 | |
3294 | bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { |
3295 | return op->hasTrait<OpTrait::SymbolTable>() && |
3296 | op->hasTrait<OpTrait::IsIsolatedFromAbove>(); |
3297 | } |
3298 | |