1 | //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the translation between an MLIR LLVM dialect module and |
10 | // the corresponding LLVMIR module. It only handles core LLVM IR operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
15 | |
16 | #include "AttrKindDetail.h" |
17 | #include "DebugTranslation.h" |
18 | #include "LoopAnnotationTranslation.h" |
19 | #include "mlir/Dialect/DLTI/DLTI.h" |
20 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
21 | #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" |
22 | #include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h" |
23 | #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" |
24 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
25 | #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" |
26 | #include "mlir/IR/AttrTypeSubElements.h" |
27 | #include "mlir/IR/Attributes.h" |
28 | #include "mlir/IR/BuiltinOps.h" |
29 | #include "mlir/IR/BuiltinTypes.h" |
30 | #include "mlir/IR/DialectResourceBlobManager.h" |
31 | #include "mlir/IR/RegionGraphTraits.h" |
32 | #include "mlir/Support/LLVM.h" |
33 | #include "mlir/Support/LogicalResult.h" |
34 | #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" |
35 | #include "mlir/Target/LLVMIR/TypeToLLVM.h" |
36 | #include "mlir/Transforms/RegionUtils.h" |
37 | |
38 | #include "llvm/ADT/PostOrderIterator.h" |
39 | #include "llvm/ADT/SetVector.h" |
40 | #include "llvm/ADT/StringExtras.h" |
41 | #include "llvm/ADT/TypeSwitch.h" |
42 | #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" |
43 | #include "llvm/IR/BasicBlock.h" |
44 | #include "llvm/IR/CFG.h" |
45 | #include "llvm/IR/Constants.h" |
46 | #include "llvm/IR/DerivedTypes.h" |
47 | #include "llvm/IR/IRBuilder.h" |
48 | #include "llvm/IR/InlineAsm.h" |
49 | #include "llvm/IR/IntrinsicsNVPTX.h" |
50 | #include "llvm/IR/LLVMContext.h" |
51 | #include "llvm/IR/MDBuilder.h" |
52 | #include "llvm/IR/Module.h" |
53 | #include "llvm/IR/Verifier.h" |
54 | #include "llvm/Support/Debug.h" |
55 | #include "llvm/Support/raw_ostream.h" |
56 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
57 | #include "llvm/Transforms/Utils/Cloning.h" |
58 | #include "llvm/Transforms/Utils/ModuleUtils.h" |
59 | #include <optional> |
60 | |
61 | #define DEBUG_TYPE "llvm-dialect-to-llvm-ir" |
62 | |
63 | using namespace mlir; |
64 | using namespace mlir::LLVM; |
65 | using namespace mlir::LLVM::detail; |
66 | |
67 | #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" |
68 | |
69 | namespace { |
70 | /// A customized inserter for LLVM's IRBuilder that captures all LLVM IR |
71 | /// instructions that are created for future reference. |
72 | /// |
73 | /// This is intended to be used with the `CollectionScope` RAII object: |
74 | /// |
75 | /// llvm::IRBuilder<..., InstructionCapturingInserter> builder; |
76 | /// { |
77 | /// InstructionCapturingInserter::CollectionScope scope(builder); |
78 | /// // Call IRBuilder methods as usual. |
79 | /// |
80 | /// // This will return a list of all instructions created by the builder, |
81 | /// // in order of creation. |
82 | /// builder.getInserter().getCapturedInstructions(); |
83 | /// } |
84 | /// // This will return an empty list. |
85 | /// builder.getInserter().getCapturedInstructions(); |
86 | /// |
87 | /// The capturing functionality is _disabled_ by default for performance |
88 | /// consideration. It needs to be explicitly enabled, which is achieved by |
89 | /// creating a `CollectionScope`. |
90 | class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter { |
91 | public: |
92 | /// Constructs the inserter. |
93 | InstructionCapturingInserter() |
94 | : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) { |
95 | if (LLVM_LIKELY(enabled)) |
96 | capturedInstructions.push_back(instruction); |
97 | }) {} |
98 | |
99 | /// Returns the list of LLVM IR instructions captured since the last cleanup. |
100 | ArrayRef<llvm::Instruction *> getCapturedInstructions() const { |
101 | return capturedInstructions; |
102 | } |
103 | |
104 | /// Clears the list of captured LLVM IR instructions. |
105 | void clearCapturedInstructions() { capturedInstructions.clear(); } |
106 | |
107 | /// RAII object enabling the capture of created LLVM IR instructions. |
108 | class CollectionScope { |
109 | public: |
110 | /// Creates the scope for the given inserter. |
111 | CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing); |
112 | |
113 | /// Ends the scope. |
114 | ~CollectionScope(); |
115 | |
116 | ArrayRef<llvm::Instruction *> getCapturedInstructions() { |
117 | if (!inserter) |
118 | return {}; |
119 | return inserter->getCapturedInstructions(); |
120 | } |
121 | |
122 | private: |
123 | /// Back reference to the inserter. |
124 | InstructionCapturingInserter *inserter = nullptr; |
125 | |
126 | /// List of instructions in the inserter prior to this scope. |
127 | SmallVector<llvm::Instruction *> previouslyCollectedInstructions; |
128 | |
129 | /// Whether the inserter was enabled prior to this scope. |
130 | bool wasEnabled; |
131 | }; |
132 | |
133 | /// Enable or disable the capturing mechanism. |
134 | void setEnabled(bool enabled = true) { this->enabled = enabled; } |
135 | |
136 | private: |
137 | /// List of captured instructions. |
138 | SmallVector<llvm::Instruction *> capturedInstructions; |
139 | |
140 | /// Whether the collection is enabled. |
141 | bool enabled = false; |
142 | }; |
143 | |
144 | using CapturingIRBuilder = |
145 | llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>; |
146 | } // namespace |
147 | |
148 | InstructionCapturingInserter::CollectionScope::CollectionScope( |
149 | llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) { |
150 | |
151 | if (!isBuilderCapturing) |
152 | return; |
153 | |
154 | auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder); |
155 | inserter = &capturingIRBuilder.getInserter(); |
156 | wasEnabled = inserter->enabled; |
157 | if (wasEnabled) |
158 | previouslyCollectedInstructions.swap(inserter->capturedInstructions); |
159 | inserter->setEnabled(true); |
160 | } |
161 | |
162 | InstructionCapturingInserter::CollectionScope::~CollectionScope() { |
163 | if (!inserter) |
164 | return; |
165 | |
166 | previouslyCollectedInstructions.swap(inserter->capturedInstructions); |
167 | // If collection was enabled (likely in another, surrounding scope), keep |
168 | // the instructions collected in this scope. |
169 | if (wasEnabled) { |
170 | llvm::append_range(inserter->capturedInstructions, |
171 | previouslyCollectedInstructions); |
172 | } |
173 | inserter->setEnabled(wasEnabled); |
174 | } |
175 | |
176 | /// Translates the given data layout spec attribute to the LLVM IR data layout. |
177 | /// Only integer, float, pointer and endianness entries are currently supported. |
178 | static FailureOr<llvm::DataLayout> |
179 | translateDataLayout(DataLayoutSpecInterface attribute, |
180 | const DataLayout &dataLayout, |
181 | std::optional<Location> loc = std::nullopt) { |
182 | if (!loc) |
183 | loc = UnknownLoc::get(attribute.getContext()); |
184 | |
185 | // Translate the endianness attribute. |
186 | std::string llvmDataLayout; |
187 | llvm::raw_string_ostream layoutStream(llvmDataLayout); |
188 | for (DataLayoutEntryInterface entry : attribute.getEntries()) { |
189 | auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey()); |
190 | if (!key) |
191 | continue; |
192 | if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) { |
193 | auto value = cast<StringAttr>(entry.getValue()); |
194 | bool isLittleEndian = |
195 | value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle; |
196 | layoutStream << "-" << (isLittleEndian ? "e" : "E" ); |
197 | layoutStream.flush(); |
198 | continue; |
199 | } |
200 | if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) { |
201 | auto value = cast<IntegerAttr>(entry.getValue()); |
202 | uint64_t space = value.getValue().getZExtValue(); |
203 | // Skip the default address space. |
204 | if (space == 0) |
205 | continue; |
206 | layoutStream << "-P" << space; |
207 | layoutStream.flush(); |
208 | continue; |
209 | } |
210 | if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) { |
211 | auto value = cast<IntegerAttr>(entry.getValue()); |
212 | uint64_t space = value.getValue().getZExtValue(); |
213 | // Skip the default address space. |
214 | if (space == 0) |
215 | continue; |
216 | layoutStream << "-G" << space; |
217 | layoutStream.flush(); |
218 | continue; |
219 | } |
220 | if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) { |
221 | auto value = cast<IntegerAttr>(entry.getValue()); |
222 | uint64_t space = value.getValue().getZExtValue(); |
223 | // Skip the default address space. |
224 | if (space == 0) |
225 | continue; |
226 | layoutStream << "-A" << space; |
227 | layoutStream.flush(); |
228 | continue; |
229 | } |
230 | if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) { |
231 | auto value = cast<IntegerAttr>(entry.getValue()); |
232 | uint64_t alignment = value.getValue().getZExtValue(); |
233 | // Skip the default stack alignment. |
234 | if (alignment == 0) |
235 | continue; |
236 | layoutStream << "-S" << alignment; |
237 | layoutStream.flush(); |
238 | continue; |
239 | } |
240 | emitError(*loc) << "unsupported data layout key " << key; |
241 | return failure(); |
242 | } |
243 | |
244 | // Go through the list of entries to check which types are explicitly |
245 | // specified in entries. Where possible, data layout queries are used instead |
246 | // of directly inspecting the entries. |
247 | for (DataLayoutEntryInterface entry : attribute.getEntries()) { |
248 | auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()); |
249 | if (!type) |
250 | continue; |
251 | // Data layout for the index type is irrelevant at this point. |
252 | if (isa<IndexType>(type)) |
253 | continue; |
254 | layoutStream << "-" ; |
255 | LogicalResult result = |
256 | llvm::TypeSwitch<Type, LogicalResult>(type) |
257 | .Case<IntegerType, Float16Type, Float32Type, Float64Type, |
258 | Float80Type, Float128Type>([&](Type type) -> LogicalResult { |
259 | if (auto intType = dyn_cast<IntegerType>(type)) { |
260 | if (intType.getSignedness() != IntegerType::Signless) |
261 | return emitError(*loc) |
262 | << "unsupported data layout for non-signless integer " |
263 | << intType; |
264 | layoutStream << "i" ; |
265 | } else { |
266 | layoutStream << "f" ; |
267 | } |
268 | uint64_t size = dataLayout.getTypeSizeInBits(type); |
269 | uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u; |
270 | uint64_t preferred = |
271 | dataLayout.getTypePreferredAlignment(type) * 8u; |
272 | layoutStream << size << ":" << abi; |
273 | if (abi != preferred) |
274 | layoutStream << ":" << preferred; |
275 | return success(); |
276 | }) |
277 | .Case([&](LLVMPointerType type) { |
278 | layoutStream << "p" << type.getAddressSpace() << ":" ; |
279 | uint64_t size = dataLayout.getTypeSizeInBits(type); |
280 | uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u; |
281 | uint64_t preferred = |
282 | dataLayout.getTypePreferredAlignment(type) * 8u; |
283 | uint64_t index = *dataLayout.getTypeIndexBitwidth(type); |
284 | layoutStream << size << ":" << abi << ":" << preferred << ":" |
285 | << index; |
286 | return success(); |
287 | }) |
288 | .Default([loc](Type type) { |
289 | return emitError(*loc) |
290 | << "unsupported type in data layout: " << type; |
291 | }); |
292 | if (failed(result)) |
293 | return failure(); |
294 | } |
295 | layoutStream.flush(); |
296 | StringRef layoutSpec(llvmDataLayout); |
297 | if (layoutSpec.starts_with(Prefix: "-" )) |
298 | layoutSpec = layoutSpec.drop_front(); |
299 | |
300 | return llvm::DataLayout(layoutSpec); |
301 | } |
302 | |
303 | /// Builds a constant of a sequential LLVM type `type`, potentially containing |
304 | /// other sequential types recursively, from the individual constant values |
305 | /// provided in `constants`. `shape` contains the number of elements in nested |
306 | /// sequential types. Reports errors at `loc` and returns nullptr on error. |
307 | static llvm::Constant * |
308 | buildSequentialConstant(ArrayRef<llvm::Constant *> &constants, |
309 | ArrayRef<int64_t> shape, llvm::Type *type, |
310 | Location loc) { |
311 | if (shape.empty()) { |
312 | llvm::Constant *result = constants.front(); |
313 | constants = constants.drop_front(); |
314 | return result; |
315 | } |
316 | |
317 | llvm::Type *elementType; |
318 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) { |
319 | elementType = arrayTy->getElementType(); |
320 | } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) { |
321 | elementType = vectorTy->getElementType(); |
322 | } else { |
323 | emitError(loc) << "expected sequential LLVM types wrapping a scalar" ; |
324 | return nullptr; |
325 | } |
326 | |
327 | SmallVector<llvm::Constant *, 8> nested; |
328 | nested.reserve(N: shape.front()); |
329 | for (int64_t i = 0; i < shape.front(); ++i) { |
330 | nested.push_back(Elt: buildSequentialConstant(constants, shape: shape.drop_front(), |
331 | type: elementType, loc)); |
332 | if (!nested.back()) |
333 | return nullptr; |
334 | } |
335 | |
336 | if (shape.size() == 1 && type->isVectorTy()) |
337 | return llvm::ConstantVector::get(V: nested); |
338 | return llvm::ConstantArray::get( |
339 | T: llvm::ArrayType::get(ElementType: elementType, NumElements: shape.front()), V: nested); |
340 | } |
341 | |
342 | /// Returns the first non-sequential type nested in sequential types. |
343 | static llvm::Type *getInnermostElementType(llvm::Type *type) { |
344 | do { |
345 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) { |
346 | type = arrayTy->getElementType(); |
347 | } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) { |
348 | type = vectorTy->getElementType(); |
349 | } else { |
350 | return type; |
351 | } |
352 | } while (true); |
353 | } |
354 | |
355 | /// Convert a dense elements attribute to an LLVM IR constant using its raw data |
356 | /// storage if possible. This supports elements attributes of tensor or vector |
357 | /// type and avoids constructing separate objects for individual values of the |
358 | /// innermost dimension. Constants for other dimensions are still constructed |
359 | /// recursively. Returns null if constructing from raw data is not supported for |
360 | /// this type, e.g., element type is not a power-of-two-sized primitive. Reports |
361 | /// other errors at `loc`. |
362 | static llvm::Constant * |
363 | convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, |
364 | llvm::Type *llvmType, |
365 | const ModuleTranslation &moduleTranslation) { |
366 | if (!denseElementsAttr) |
367 | return nullptr; |
368 | |
369 | llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType); |
370 | if (!llvm::ConstantDataSequential::isElementTypeCompatible(Ty: innermostLLVMType)) |
371 | return nullptr; |
372 | |
373 | ShapedType type = denseElementsAttr.getType(); |
374 | if (type.getNumElements() == 0) |
375 | return nullptr; |
376 | |
377 | // Check that the raw data size matches what is expected for the scalar size. |
378 | // TODO: in theory, we could repack the data here to keep constructing from |
379 | // raw data. |
380 | // TODO: we may also need to consider endianness when cross-compiling to an |
381 | // architecture where it is different. |
382 | int64_t elementByteSize = denseElementsAttr.getRawData().size() / |
383 | denseElementsAttr.getNumElements(); |
384 | if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) |
385 | return nullptr; |
386 | |
387 | // Compute the shape of all dimensions but the innermost. Note that the |
388 | // innermost dimension may be that of the vector element type. |
389 | bool hasVectorElementType = isa<VectorType>(type.getElementType()); |
390 | int64_t numAggregates = |
391 | denseElementsAttr.getNumElements() / |
392 | (hasVectorElementType ? 1 |
393 | : denseElementsAttr.getType().getShape().back()); |
394 | ArrayRef<int64_t> outerShape = type.getShape(); |
395 | if (!hasVectorElementType) |
396 | outerShape = outerShape.drop_back(); |
397 | |
398 | // Handle the case of vector splat, LLVM has special support for it. |
399 | if (denseElementsAttr.isSplat() && |
400 | (isa<VectorType>(type) || hasVectorElementType)) { |
401 | llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( |
402 | llvmType: innermostLLVMType, attr: denseElementsAttr.getSplatValue<Attribute>(), loc, |
403 | moduleTranslation); |
404 | llvm::Constant *splatVector = |
405 | llvm::ConstantDataVector::getSplat(NumElts: 0, Elt: splatValue); |
406 | SmallVector<llvm::Constant *> constants(numAggregates, splatVector); |
407 | ArrayRef<llvm::Constant *> constantsRef = constants; |
408 | return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc); |
409 | } |
410 | if (denseElementsAttr.isSplat()) |
411 | return nullptr; |
412 | |
413 | // In case of non-splat, create a constructor for the innermost constant from |
414 | // a piece of raw data. |
415 | std::function<llvm::Constant *(StringRef)> buildCstData; |
416 | if (isa<TensorType>(type)) { |
417 | auto vectorElementType = dyn_cast<VectorType>(type.getElementType()); |
418 | if (vectorElementType && vectorElementType.getRank() == 1) { |
419 | buildCstData = [&](StringRef data) { |
420 | return llvm::ConstantDataVector::getRaw( |
421 | data, vectorElementType.getShape().back(), innermostLLVMType); |
422 | }; |
423 | } else if (!vectorElementType) { |
424 | buildCstData = [&](StringRef data) { |
425 | return llvm::ConstantDataArray::getRaw(data, type.getShape().back(), |
426 | innermostLLVMType); |
427 | }; |
428 | } |
429 | } else if (isa<VectorType>(type)) { |
430 | buildCstData = [&](StringRef data) { |
431 | return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), |
432 | innermostLLVMType); |
433 | }; |
434 | } |
435 | if (!buildCstData) |
436 | return nullptr; |
437 | |
438 | // Create innermost constants and defer to the default constant creation |
439 | // mechanism for other dimensions. |
440 | SmallVector<llvm::Constant *> constants; |
441 | int64_t aggregateSize = denseElementsAttr.getType().getShape().back() * |
442 | (innermostLLVMType->getScalarSizeInBits() / 8); |
443 | constants.reserve(N: numAggregates); |
444 | for (unsigned i = 0; i < numAggregates; ++i) { |
445 | StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize, |
446 | aggregateSize); |
447 | constants.push_back(Elt: buildCstData(data)); |
448 | } |
449 | |
450 | ArrayRef<llvm::Constant *> constantsRef = constants; |
451 | return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc); |
452 | } |
453 | |
454 | /// Convert a dense resource elements attribute to an LLVM IR constant using its |
455 | /// raw data storage if possible. This supports elements attributes of tensor or |
456 | /// vector type and avoids constructing separate objects for individual values |
457 | /// of the innermost dimension. Constants for other dimensions are still |
458 | /// constructed recursively. Returns nullptr on failure and emits errors at |
459 | /// `loc`. |
460 | static llvm::Constant *convertDenseResourceElementsAttr( |
461 | Location loc, DenseResourceElementsAttr denseResourceAttr, |
462 | llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) { |
463 | assert(denseResourceAttr && "expected non-null attribute" ); |
464 | |
465 | llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType); |
466 | if (!llvm::ConstantDataSequential::isElementTypeCompatible( |
467 | Ty: innermostLLVMType)) { |
468 | emitError(loc, message: "no known conversion for innermost element type" ); |
469 | return nullptr; |
470 | } |
471 | |
472 | ShapedType type = denseResourceAttr.getType(); |
473 | assert(type.getNumElements() > 0 && "Expected non-empty elements attribute" ); |
474 | |
475 | AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob(); |
476 | if (!blob) { |
477 | emitError(loc, message: "resource does not exist" ); |
478 | return nullptr; |
479 | } |
480 | |
481 | ArrayRef<char> rawData = blob->getData(); |
482 | |
483 | // Check that the raw data size matches what is expected for the scalar size. |
484 | // TODO: in theory, we could repack the data here to keep constructing from |
485 | // raw data. |
486 | // TODO: we may also need to consider endianness when cross-compiling to an |
487 | // architecture where it is different. |
488 | int64_t numElements = denseResourceAttr.getType().getNumElements(); |
489 | int64_t elementByteSize = rawData.size() / numElements; |
490 | if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) { |
491 | emitError(loc, message: "raw data size does not match element type size" ); |
492 | return nullptr; |
493 | } |
494 | |
495 | // Compute the shape of all dimensions but the innermost. Note that the |
496 | // innermost dimension may be that of the vector element type. |
497 | bool hasVectorElementType = isa<VectorType>(type.getElementType()); |
498 | int64_t numAggregates = |
499 | numElements / (hasVectorElementType |
500 | ? 1 |
501 | : denseResourceAttr.getType().getShape().back()); |
502 | ArrayRef<int64_t> outerShape = type.getShape(); |
503 | if (!hasVectorElementType) |
504 | outerShape = outerShape.drop_back(); |
505 | |
506 | // Create a constructor for the innermost constant from a piece of raw data. |
507 | std::function<llvm::Constant *(StringRef)> buildCstData; |
508 | if (isa<TensorType>(type)) { |
509 | auto vectorElementType = dyn_cast<VectorType>(type.getElementType()); |
510 | if (vectorElementType && vectorElementType.getRank() == 1) { |
511 | buildCstData = [&](StringRef data) { |
512 | return llvm::ConstantDataVector::getRaw( |
513 | data, vectorElementType.getShape().back(), innermostLLVMType); |
514 | }; |
515 | } else if (!vectorElementType) { |
516 | buildCstData = [&](StringRef data) { |
517 | return llvm::ConstantDataArray::getRaw(data, type.getShape().back(), |
518 | innermostLLVMType); |
519 | }; |
520 | } |
521 | } else if (isa<VectorType>(type)) { |
522 | buildCstData = [&](StringRef data) { |
523 | return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), |
524 | innermostLLVMType); |
525 | }; |
526 | } |
527 | if (!buildCstData) { |
528 | emitError(loc, message: "unsupported dense_resource type" ); |
529 | return nullptr; |
530 | } |
531 | |
532 | // Create innermost constants and defer to the default constant creation |
533 | // mechanism for other dimensions. |
534 | SmallVector<llvm::Constant *> constants; |
535 | int64_t aggregateSize = denseResourceAttr.getType().getShape().back() * |
536 | (innermostLLVMType->getScalarSizeInBits() / 8); |
537 | constants.reserve(N: numAggregates); |
538 | for (unsigned i = 0; i < numAggregates; ++i) { |
539 | StringRef data(rawData.data() + i * aggregateSize, aggregateSize); |
540 | constants.push_back(Elt: buildCstData(data)); |
541 | } |
542 | |
543 | ArrayRef<llvm::Constant *> constantsRef = constants; |
544 | return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc); |
545 | } |
546 | |
547 | /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. |
548 | /// This currently supports integer, floating point, splat and dense element |
549 | /// attributes and combinations thereof. Also, an array attribute with two |
550 | /// elements is supported to represent a complex constant. In case of error, |
551 | /// report it to `loc` and return nullptr. |
552 | llvm::Constant *mlir::LLVM::detail::getLLVMConstant( |
553 | llvm::Type *llvmType, Attribute attr, Location loc, |
554 | const ModuleTranslation &moduleTranslation) { |
555 | if (!attr) |
556 | return llvm::UndefValue::get(T: llvmType); |
557 | if (auto *structType = dyn_cast<::llvm::StructType>(Val: llvmType)) { |
558 | auto arrayAttr = dyn_cast<ArrayAttr>(attr); |
559 | if (!arrayAttr || arrayAttr.size() != 2) { |
560 | emitError(loc, message: "expected struct type to be a complex number" ); |
561 | return nullptr; |
562 | } |
563 | llvm::Type *elementType = structType->getElementType(N: 0); |
564 | llvm::Constant *real = |
565 | getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); |
566 | if (!real) |
567 | return nullptr; |
568 | llvm::Constant *imag = |
569 | getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); |
570 | if (!imag) |
571 | return nullptr; |
572 | return llvm::ConstantStruct::get(T: structType, V: {real, imag}); |
573 | } |
574 | // For integer types, we allow a mismatch in sizes as the index type in |
575 | // MLIR might have a different size than the index type in the LLVM module. |
576 | if (auto intAttr = dyn_cast<IntegerAttr>(attr)) |
577 | return llvm::ConstantInt::get( |
578 | llvmType, |
579 | intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); |
580 | if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { |
581 | const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); |
582 | // Special case for 8-bit floats, which are represented by integers due to |
583 | // the lack of native fp8 types in LLVM at the moment. Additionally, handle |
584 | // targets (like AMDGPU) that don't implement bfloat and convert all bfloats |
585 | // to i16. |
586 | unsigned floatWidth = APFloat::getSizeInBits(Sem: sem); |
587 | if (llvmType->isIntegerTy(Bitwidth: floatWidth)) |
588 | return llvm::ConstantInt::get(llvmType, |
589 | floatAttr.getValue().bitcastToAPInt()); |
590 | if (llvmType != |
591 | llvm::Type::getFloatingPointTy(C&: llvmType->getContext(), |
592 | S: floatAttr.getValue().getSemantics())) { |
593 | emitError(loc, message: "FloatAttr does not match expected type of the constant" ); |
594 | return nullptr; |
595 | } |
596 | return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); |
597 | } |
598 | if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr)) |
599 | return llvm::ConstantExpr::getBitCast( |
600 | C: moduleTranslation.lookupFunction(name: funcAttr.getValue()), Ty: llvmType); |
601 | if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr)) { |
602 | llvm::Type *elementType; |
603 | uint64_t numElements; |
604 | bool isScalable = false; |
605 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) { |
606 | elementType = arrayTy->getElementType(); |
607 | numElements = arrayTy->getNumElements(); |
608 | } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(Val: llvmType)) { |
609 | elementType = fVectorTy->getElementType(); |
610 | numElements = fVectorTy->getNumElements(); |
611 | } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(Val: llvmType)) { |
612 | elementType = sVectorTy->getElementType(); |
613 | numElements = sVectorTy->getMinNumElements(); |
614 | isScalable = true; |
615 | } else { |
616 | llvm_unreachable("unrecognized constant vector type" ); |
617 | } |
618 | // Splat value is a scalar. Extract it only if the element type is not |
619 | // another sequence type. The recursion terminates because each step removes |
620 | // one outer sequential type. |
621 | bool elementTypeSequential = |
622 | isa<llvm::ArrayType, llvm::VectorType>(Val: elementType); |
623 | llvm::Constant *child = getLLVMConstant( |
624 | llvmType: elementType, |
625 | attr: elementTypeSequential ? splatAttr |
626 | : splatAttr.getSplatValue<Attribute>(), |
627 | loc, moduleTranslation); |
628 | if (!child) |
629 | return nullptr; |
630 | if (llvmType->isVectorTy()) |
631 | return llvm::ConstantVector::getSplat( |
632 | EC: llvm::ElementCount::get(MinVal: numElements, /*Scalable=*/isScalable), Elt: child); |
633 | if (llvmType->isArrayTy()) { |
634 | auto *arrayType = llvm::ArrayType::get(ElementType: elementType, NumElements: numElements); |
635 | SmallVector<llvm::Constant *, 8> constants(numElements, child); |
636 | return llvm::ConstantArray::get(T: arrayType, V: constants); |
637 | } |
638 | } |
639 | |
640 | // Try using raw elements data if possible. |
641 | if (llvm::Constant *result = |
642 | convertDenseElementsAttr(loc, denseElementsAttr: dyn_cast<DenseElementsAttr>(Val&: attr), |
643 | llvmType, moduleTranslation)) { |
644 | return result; |
645 | } |
646 | |
647 | if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) { |
648 | return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType, |
649 | moduleTranslation); |
650 | } |
651 | |
652 | // Fall back to element-by-element construction otherwise. |
653 | if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) { |
654 | assert(elementsAttr.getShapedType().hasStaticShape()); |
655 | assert(!elementsAttr.getShapedType().getShape().empty() && |
656 | "unexpected empty elements attribute shape" ); |
657 | |
658 | SmallVector<llvm::Constant *, 8> constants; |
659 | constants.reserve(N: elementsAttr.getNumElements()); |
660 | llvm::Type *innermostType = getInnermostElementType(type: llvmType); |
661 | for (auto n : elementsAttr.getValues<Attribute>()) { |
662 | constants.push_back( |
663 | getLLVMConstant(innermostType, n, loc, moduleTranslation)); |
664 | if (!constants.back()) |
665 | return nullptr; |
666 | } |
667 | ArrayRef<llvm::Constant *> constantsRef = constants; |
668 | llvm::Constant *result = buildSequentialConstant( |
669 | constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc); |
670 | assert(constantsRef.empty() && "did not consume all elemental constants" ); |
671 | return result; |
672 | } |
673 | |
674 | if (auto stringAttr = dyn_cast<StringAttr>(attr)) { |
675 | return llvm::ConstantDataArray::get( |
676 | Context&: moduleTranslation.getLLVMContext(), |
677 | Elts: ArrayRef<char>{stringAttr.getValue().data(), |
678 | stringAttr.getValue().size()}); |
679 | } |
680 | emitError(loc, message: "unsupported constant value" ); |
681 | return nullptr; |
682 | } |
683 | |
684 | ModuleTranslation::ModuleTranslation(Operation *module, |
685 | std::unique_ptr<llvm::Module> llvmModule) |
686 | : mlirModule(module), llvmModule(std::move(llvmModule)), |
687 | debugTranslation( |
688 | std::make_unique<DebugTranslation>(args&: module, args&: *this->llvmModule)), |
689 | loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>( |
690 | args&: *this, args&: *this->llvmModule)), |
691 | typeTranslator(this->llvmModule->getContext()), |
692 | iface(module->getContext()) { |
693 | assert(satisfiesLLVMModule(mlirModule) && |
694 | "mlirModule should honor LLVM's module semantics." ); |
695 | } |
696 | |
697 | ModuleTranslation::~ModuleTranslation() { |
698 | if (ompBuilder) |
699 | ompBuilder->finalize(); |
700 | } |
701 | |
702 | void ModuleTranslation::forgetMapping(Region ®ion) { |
703 | SmallVector<Region *> toProcess; |
704 | toProcess.push_back(Elt: ®ion); |
705 | while (!toProcess.empty()) { |
706 | Region *current = toProcess.pop_back_val(); |
707 | for (Block &block : *current) { |
708 | blockMapping.erase(Val: &block); |
709 | for (Value arg : block.getArguments()) |
710 | valueMapping.erase(Val: arg); |
711 | for (Operation &op : block) { |
712 | for (Value value : op.getResults()) |
713 | valueMapping.erase(Val: value); |
714 | if (op.hasSuccessors()) |
715 | branchMapping.erase(Val: &op); |
716 | if (isa<LLVM::GlobalOp>(op)) |
717 | globalsMapping.erase(Val: &op); |
718 | if (isa<LLVM::CallOp>(op)) |
719 | callMapping.erase(Val: &op); |
720 | llvm::append_range( |
721 | C&: toProcess, |
722 | R: llvm::map_range(C: op.getRegions(), F: [](Region &r) { return &r; })); |
723 | } |
724 | } |
725 | } |
726 | } |
727 | |
728 | /// Get the SSA value passed to the current block from the terminator operation |
729 | /// of its predecessor. |
730 | static Value getPHISourceValue(Block *current, Block *pred, |
731 | unsigned numArguments, unsigned index) { |
732 | Operation &terminator = *pred->getTerminator(); |
733 | if (isa<LLVM::BrOp>(terminator)) |
734 | return terminator.getOperand(idx: index); |
735 | |
736 | #ifndef NDEBUG |
737 | llvm::SmallPtrSet<Block *, 4> seenSuccessors; |
738 | for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) { |
739 | Block *successor = terminator.getSuccessor(index: i); |
740 | auto branch = cast<BranchOpInterface>(terminator); |
741 | SuccessorOperands successorOperands = branch.getSuccessorOperands(i); |
742 | assert( |
743 | (!seenSuccessors.contains(successor) || successorOperands.empty()) && |
744 | "successors with arguments in LLVM branches must be different blocks" ); |
745 | seenSuccessors.insert(Ptr: successor); |
746 | } |
747 | #endif |
748 | |
749 | // For instructions that branch based on a condition value, we need to take |
750 | // the operands for the branch that was taken. |
751 | if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) { |
752 | // For conditional branches, we take the operands from either the "true" or |
753 | // the "false" branch. |
754 | return condBranchOp.getSuccessor(0) == current |
755 | ? condBranchOp.getTrueDestOperands()[index] |
756 | : condBranchOp.getFalseDestOperands()[index]; |
757 | } |
758 | |
759 | if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) { |
760 | // For switches, we take the operands from either the default case, or from |
761 | // the case branch that was taken. |
762 | if (switchOp.getDefaultDestination() == current) |
763 | return switchOp.getDefaultOperands()[index]; |
764 | for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations())) |
765 | if (i.value() == current) |
766 | return switchOp.getCaseOperands(i.index())[index]; |
767 | } |
768 | |
769 | if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) { |
770 | return invokeOp.getNormalDest() == current |
771 | ? invokeOp.getNormalDestOperands()[index] |
772 | : invokeOp.getUnwindDestOperands()[index]; |
773 | } |
774 | |
775 | llvm_unreachable( |
776 | "only branch, switch or invoke operations can be terminators " |
777 | "of a block that has successors" ); |
778 | } |
779 | |
780 | /// Connect the PHI nodes to the results of preceding blocks. |
781 | void mlir::LLVM::detail::connectPHINodes(Region ®ion, |
782 | const ModuleTranslation &state) { |
783 | // Skip the first block, it cannot be branched to and its arguments correspond |
784 | // to the arguments of the LLVM function. |
785 | for (Block &bb : llvm::drop_begin(RangeOrContainer&: region)) { |
786 | llvm::BasicBlock *llvmBB = state.lookupBlock(block: &bb); |
787 | auto phis = llvmBB->phis(); |
788 | auto numArguments = bb.getNumArguments(); |
789 | assert(numArguments == std::distance(phis.begin(), phis.end())); |
790 | for (auto [index, phiNode] : llvm::enumerate(First&: phis)) { |
791 | for (auto *pred : bb.getPredecessors()) { |
792 | // Find the LLVM IR block that contains the converted terminator |
793 | // instruction and use it in the PHI node. Note that this block is not |
794 | // necessarily the same as state.lookupBlock(pred), some operations |
795 | // (in particular, OpenMP operations using OpenMPIRBuilder) may have |
796 | // split the blocks. |
797 | llvm::Instruction *terminator = |
798 | state.lookupBranch(op: pred->getTerminator()); |
799 | assert(terminator && "missing the mapping for a terminator" ); |
800 | phiNode.addIncoming(V: state.lookupValue(value: getPHISourceValue( |
801 | current: &bb, pred, numArguments, index)), |
802 | BB: terminator->getParent()); |
803 | } |
804 | } |
805 | } |
806 | } |
807 | |
808 | llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( |
809 | llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, |
810 | ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) { |
811 | llvm::Module *module = builder.GetInsertBlock()->getModule(); |
812 | llvm::Function *fn = llvm::Intrinsic::getDeclaration(M: module, id: intrinsic, Tys: tys); |
813 | return builder.CreateCall(Callee: fn, Args: args); |
814 | } |
815 | |
816 | llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( |
817 | llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation, |
818 | Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults, |
819 | ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands, |
820 | ArrayRef<unsigned> immArgPositions, |
821 | ArrayRef<StringLiteral> immArgAttrNames) { |
822 | assert(immArgPositions.size() == immArgAttrNames.size() && |
823 | "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal " |
824 | "length" ); |
825 | |
826 | // Map operands and attributes to LLVM values. |
827 | auto operands = moduleTranslation.lookupValues(values: intrOp->getOperands()); |
828 | SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size()); |
829 | for (auto [immArgPos, immArgName] : |
830 | llvm::zip(t&: immArgPositions, u&: immArgAttrNames)) { |
831 | auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName)); |
832 | assert(attr.getType().isIntOrFloat() && "expected int or float immarg" ); |
833 | auto *type = moduleTranslation.convertType(type: attr.getType()); |
834 | args[immArgPos] = LLVM::detail::getLLVMConstant( |
835 | llvmType: type, attr: attr, loc: intrOp->getLoc(), moduleTranslation); |
836 | } |
837 | unsigned opArg = 0; |
838 | for (auto &arg : args) { |
839 | if (!arg) |
840 | arg = operands[opArg++]; |
841 | } |
842 | |
843 | // Resolve overloaded intrinsic declaration. |
844 | SmallVector<llvm::Type *> overloadedTypes; |
845 | for (unsigned overloadedResultIdx : overloadedResults) { |
846 | if (numResults > 1) { |
847 | // More than one result is mapped to an LLVM struct. |
848 | overloadedTypes.push_back(Elt: moduleTranslation.convertType( |
849 | type: llvm::cast<LLVM::LLVMStructType>(Val: intrOp->getResult(idx: 0).getType()) |
850 | .getBody()[overloadedResultIdx])); |
851 | } else { |
852 | overloadedTypes.push_back( |
853 | Elt: moduleTranslation.convertType(type: intrOp->getResult(idx: 0).getType())); |
854 | } |
855 | } |
856 | for (unsigned overloadedOperandIdx : overloadedOperands) |
857 | overloadedTypes.push_back(Elt: args[overloadedOperandIdx]->getType()); |
858 | llvm::Module *module = builder.GetInsertBlock()->getModule(); |
859 | llvm::Function *llvmIntr = |
860 | llvm::Intrinsic::getDeclaration(M: module, id: intrinsic, Tys: overloadedTypes); |
861 | |
862 | return builder.CreateCall(Callee: llvmIntr, Args: args); |
863 | } |
864 | |
865 | /// Given a single MLIR operation, create the corresponding LLVM IR operation |
866 | /// using the `builder`. |
867 | LogicalResult ModuleTranslation::convertOperation(Operation &op, |
868 | llvm::IRBuilderBase &builder, |
869 | bool recordInsertions) { |
870 | const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(obj: &op); |
871 | if (!opIface) |
872 | return op.emitError(message: "cannot be converted to LLVM IR: missing " |
873 | "`LLVMTranslationDialectInterface` registration for " |
874 | "dialect for op: " ) |
875 | << op.getName(); |
876 | |
877 | InstructionCapturingInserter::CollectionScope scope(builder, |
878 | recordInsertions); |
879 | if (failed(result: opIface->convertOperation(op: &op, builder, moduleTranslation&: *this))) |
880 | return op.emitError(message: "LLVM Translation failed for operation: " ) |
881 | << op.getName(); |
882 | |
883 | return convertDialectAttributes(op: &op, instructions: scope.getCapturedInstructions()); |
884 | } |
885 | |
886 | /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes |
887 | /// to define values corresponding to the MLIR block arguments. These nodes |
888 | /// are not connected to the source basic blocks, which may not exist yet. Uses |
889 | /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have |
890 | /// been created for `bb` and included in the block mapping. Inserts new |
891 | /// instructions at the end of the block and leaves `builder` in a state |
892 | /// suitable for further insertion into the end of the block. |
893 | LogicalResult ModuleTranslation::convertBlockImpl(Block &bb, |
894 | bool ignoreArguments, |
895 | llvm::IRBuilderBase &builder, |
896 | bool recordInsertions) { |
897 | builder.SetInsertPoint(lookupBlock(block: &bb)); |
898 | auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); |
899 | |
900 | // Before traversing operations, make block arguments available through |
901 | // value remapping and PHI nodes, but do not add incoming edges for the PHI |
902 | // nodes just yet: those values may be defined by this or following blocks. |
903 | // This step is omitted if "ignoreArguments" is set. The arguments of the |
904 | // first block have been already made available through the remapping of |
905 | // LLVM function arguments. |
906 | if (!ignoreArguments) { |
907 | auto predecessors = bb.getPredecessors(); |
908 | unsigned numPredecessors = |
909 | std::distance(first: predecessors.begin(), last: predecessors.end()); |
910 | for (auto arg : bb.getArguments()) { |
911 | auto wrappedType = arg.getType(); |
912 | if (!isCompatibleType(type: wrappedType)) |
913 | return emitError(loc: bb.front().getLoc(), |
914 | message: "block argument does not have an LLVM type" ); |
915 | llvm::Type *type = convertType(type: wrappedType); |
916 | llvm::PHINode *phi = builder.CreatePHI(Ty: type, NumReservedValues: numPredecessors); |
917 | mapValue(mlir: arg, llvm: phi); |
918 | } |
919 | } |
920 | |
921 | // Traverse operations. |
922 | for (auto &op : bb) { |
923 | // Set the current debug location within the builder. |
924 | builder.SetCurrentDebugLocation( |
925 | debugTranslation->translateLoc(loc: op.getLoc(), scope: subprogram)); |
926 | |
927 | if (failed(result: convertOperation(op, builder, recordInsertions))) |
928 | return failure(); |
929 | |
930 | // Set the branch weight metadata on the translated instruction. |
931 | if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) |
932 | setBranchWeightsMetadata(iface); |
933 | } |
934 | |
935 | return success(); |
936 | } |
937 | |
938 | /// A helper method to get the single Block in an operation honoring LLVM's |
939 | /// module requirements. |
940 | static Block &getModuleBody(Operation *module) { |
941 | return module->getRegion(index: 0).front(); |
942 | } |
943 | |
944 | /// A helper method to decide if a constant must not be set as a global variable |
945 | /// initializer. For an external linkage variable, the variable with an |
946 | /// initializer is considered externally visible and defined in this module, the |
947 | /// variable without an initializer is externally available and is defined |
948 | /// elsewhere. |
949 | static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, |
950 | llvm::Constant *cst) { |
951 | return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) || |
952 | linkage == llvm::GlobalVariable::ExternalWeakLinkage; |
953 | } |
954 | |
955 | /// Sets the runtime preemption specifier of `gv` to dso_local if |
956 | /// `dsoLocalRequested` is true, otherwise it is left unchanged. |
957 | static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, |
958 | llvm::GlobalValue *gv) { |
959 | if (dsoLocalRequested) |
960 | gv->setDSOLocal(true); |
961 | } |
962 | |
963 | /// Create named global variables that correspond to llvm.mlir.global |
964 | /// definitions. Convert llvm.global_ctors and global_dtors ops. |
965 | LogicalResult ModuleTranslation::convertGlobals() { |
966 | // Mapping from compile unit to its respective set of global variables. |
967 | DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars; |
968 | |
969 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { |
970 | llvm::Type *type = convertType(op.getType()); |
971 | llvm::Constant *cst = nullptr; |
972 | if (op.getValueOrNull()) { |
973 | // String attributes are treated separately because they cannot appear as |
974 | // in-function constants and are thus not supported by getLLVMConstant. |
975 | if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) { |
976 | cst = llvm::ConstantDataArray::getString( |
977 | llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); |
978 | type = cst->getType(); |
979 | } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(), |
980 | *this))) { |
981 | return failure(); |
982 | } |
983 | } |
984 | |
985 | auto linkage = convertLinkageToLLVM(op.getLinkage()); |
986 | auto addrSpace = op.getAddrSpace(); |
987 | |
988 | // LLVM IR requires constant with linkage other than external or weak |
989 | // external to have initializers. If MLIR does not provide an initializer, |
990 | // default to undef. |
991 | bool dropInitializer = shouldDropGlobalInitializer(linkage, cst); |
992 | if (!dropInitializer && !cst) |
993 | cst = llvm::UndefValue::get(type); |
994 | else if (dropInitializer && cst) |
995 | cst = nullptr; |
996 | |
997 | auto *var = new llvm::GlobalVariable( |
998 | *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(), |
999 | /*InsertBefore=*/nullptr, |
1000 | op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel |
1001 | : llvm::GlobalValue::NotThreadLocal, |
1002 | addrSpace); |
1003 | |
1004 | if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) { |
1005 | auto selectorOp = cast<ComdatSelectorOp>( |
1006 | SymbolTable::lookupNearestSymbolFrom(op, *comdat)); |
1007 | var->setComdat(comdatMapping.lookup(selectorOp)); |
1008 | } |
1009 | |
1010 | if (op.getUnnamedAddr().has_value()) |
1011 | var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); |
1012 | |
1013 | if (op.getSection().has_value()) |
1014 | var->setSection(*op.getSection()); |
1015 | |
1016 | addRuntimePreemptionSpecifier(op.getDsoLocal(), var); |
1017 | |
1018 | std::optional<uint64_t> alignment = op.getAlignment(); |
1019 | if (alignment.has_value()) |
1020 | var->setAlignment(llvm::MaybeAlign(alignment.value())); |
1021 | |
1022 | var->setVisibility(convertVisibilityToLLVM(op.getVisibility_())); |
1023 | |
1024 | globalsMapping.try_emplace(op, var); |
1025 | |
1026 | // Add debug information if present. |
1027 | if (op.getDbgExpr()) { |
1028 | llvm::DIGlobalVariableExpression *diGlobalExpr = |
1029 | debugTranslation->translateGlobalVariableExpression(op.getDbgExpr()); |
1030 | llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable(); |
1031 | var->addDebugInfo(diGlobalExpr); |
1032 | |
1033 | // Get the compile unit (scope) of the the global variable. |
1034 | if (llvm::DICompileUnit *compileUnit = |
1035 | dyn_cast_if_present<llvm::DICompileUnit>( |
1036 | diGlobalVar->getScope())) { |
1037 | // Update the compile unit with this incoming global variable expression |
1038 | // during the finalizing step later. |
1039 | allGVars[compileUnit].push_back(diGlobalExpr); |
1040 | } |
1041 | } |
1042 | } |
1043 | |
1044 | // Convert global variable bodies. This is done after all global variables |
1045 | // have been created in LLVM IR because a global body may refer to another |
1046 | // global or itself. So all global variables need to be mapped first. |
1047 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { |
1048 | if (Block *initializer = op.getInitializerBlock()) { |
1049 | llvm::IRBuilder<> builder(llvmModule->getContext()); |
1050 | |
1051 | [[maybe_unused]] int numConstantsHit = 0; |
1052 | [[maybe_unused]] int numConstantsErased = 0; |
1053 | DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap; |
1054 | |
1055 | for (auto &op : initializer->without_terminator()) { |
1056 | if (failed(convertOperation(op, builder))) |
1057 | return emitError(op.getLoc(), "fail to convert global initializer" ); |
1058 | auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0))); |
1059 | if (!cst) |
1060 | return emitError(op.getLoc(), "unemittable constant value" ); |
1061 | |
1062 | // When emitting an LLVM constant, a new constant is created and the old |
1063 | // constant may become dangling and take space. We should remove the |
1064 | // dangling constants to avoid memory explosion especially for constant |
1065 | // arrays whose number of elements is large. |
1066 | // Because multiple operations may refer to the same constant, we need |
1067 | // to count the number of uses of each constant array and remove it only |
1068 | // when the count becomes zero. |
1069 | if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) { |
1070 | numConstantsHit++; |
1071 | Value result = op.getResult(0); |
1072 | int numUsers = std::distance(result.use_begin(), result.use_end()); |
1073 | auto [iterator, inserted] = |
1074 | constantAggregateUseMap.try_emplace(agg, numUsers); |
1075 | if (!inserted) { |
1076 | // Key already exists, update the value |
1077 | iterator->second += numUsers; |
1078 | } |
1079 | } |
1080 | // Scan the operands of the operation to decrement the use count of |
1081 | // constants. Erase the constant if the use count becomes zero. |
1082 | for (Value v : op.getOperands()) { |
1083 | auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v)); |
1084 | if (!cst) |
1085 | continue; |
1086 | auto iter = constantAggregateUseMap.find(cst); |
1087 | assert(iter != constantAggregateUseMap.end() && "constant not found" ); |
1088 | iter->second--; |
1089 | if (iter->second == 0) { |
1090 | // NOTE: cannot call removeDeadConstantUsers() here because it |
1091 | // may remove the constant which has uses not be converted yet. |
1092 | if (cst->user_empty()) { |
1093 | cst->destroyConstant(); |
1094 | numConstantsErased++; |
1095 | } |
1096 | constantAggregateUseMap.erase(iter); |
1097 | } |
1098 | } |
1099 | } |
1100 | |
1101 | ReturnOp ret = cast<ReturnOp>(initializer->getTerminator()); |
1102 | llvm::Constant *cst = |
1103 | cast<llvm::Constant>(lookupValue(ret.getOperand(0))); |
1104 | auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op)); |
1105 | if (!shouldDropGlobalInitializer(global->getLinkage(), cst)) |
1106 | global->setInitializer(cst); |
1107 | |
1108 | // Try to remove the dangling constants again after all operations are |
1109 | // converted. |
1110 | for (auto it : constantAggregateUseMap) { |
1111 | auto cst = it.first; |
1112 | cst->removeDeadConstantUsers(); |
1113 | if (cst->user_empty()) { |
1114 | cst->destroyConstant(); |
1115 | numConstantsErased++; |
1116 | } |
1117 | } |
1118 | |
1119 | LLVM_DEBUG(llvm::dbgs() |
1120 | << "Convert initializer for " << op.getName() << "\n" ; |
1121 | llvm::dbgs() << numConstantsHit << " new constants hit\n" ; |
1122 | llvm::dbgs() |
1123 | << numConstantsErased << " dangling constants erased\n" ;); |
1124 | } |
1125 | } |
1126 | |
1127 | // Convert llvm.mlir.global_ctors and dtors. |
1128 | for (Operation &op : getModuleBody(module: mlirModule)) { |
1129 | auto ctorOp = dyn_cast<GlobalCtorsOp>(op); |
1130 | auto dtorOp = dyn_cast<GlobalDtorsOp>(op); |
1131 | if (!ctorOp && !dtorOp) |
1132 | continue; |
1133 | auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities()) |
1134 | : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities()); |
1135 | auto appendGlobalFn = |
1136 | ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; |
1137 | for (auto symbolAndPriority : range) { |
1138 | llvm::Function *f = lookupFunction( |
1139 | cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue()); |
1140 | appendGlobalFn(*llvmModule, f, |
1141 | cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(), |
1142 | /*Data=*/nullptr); |
1143 | } |
1144 | } |
1145 | |
1146 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) |
1147 | if (failed(convertDialectAttributes(op, {}))) |
1148 | return failure(); |
1149 | |
1150 | // Finally, update the compile units their respective sets of global variables |
1151 | // created earlier. |
1152 | for (const auto &[compileUnit, globals] : allGVars) { |
1153 | compileUnit->replaceGlobalVariables( |
1154 | N: llvm::MDTuple::get(Context&: getLLVMContext(), MDs: globals)); |
1155 | } |
1156 | |
1157 | return success(); |
1158 | } |
1159 | |
1160 | /// Attempts to add an attribute identified by `key`, optionally with the given |
1161 | /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the |
1162 | /// attribute has a kind known to LLVM IR, create the attribute of this kind, |
1163 | /// otherwise keep it as a string attribute. Performs additional checks for |
1164 | /// attributes known to have or not have a value in order to avoid assertions |
1165 | /// inside LLVM upon construction. |
1166 | static LogicalResult checkedAddLLVMFnAttribute(Location loc, |
1167 | llvm::Function *llvmFunc, |
1168 | StringRef key, |
1169 | StringRef value = StringRef()) { |
1170 | auto kind = llvm::Attribute::getAttrKindFromName(AttrName: key); |
1171 | if (kind == llvm::Attribute::None) { |
1172 | llvmFunc->addFnAttr(Kind: key, Val: value); |
1173 | return success(); |
1174 | } |
1175 | |
1176 | if (llvm::Attribute::isIntAttrKind(Kind: kind)) { |
1177 | if (value.empty()) |
1178 | return emitError(loc) << "LLVM attribute '" << key << "' expects a value" ; |
1179 | |
1180 | int64_t result; |
1181 | if (!value.getAsInteger(/*Radix=*/0, Result&: result)) |
1182 | llvmFunc->addFnAttr( |
1183 | Attr: llvm::Attribute::get(Context&: llvmFunc->getContext(), Kind: kind, Val: result)); |
1184 | else |
1185 | llvmFunc->addFnAttr(Kind: key, Val: value); |
1186 | return success(); |
1187 | } |
1188 | |
1189 | if (!value.empty()) |
1190 | return emitError(loc) << "LLVM attribute '" << key |
1191 | << "' does not expect a value, found '" << value |
1192 | << "'" ; |
1193 | |
1194 | llvmFunc->addFnAttr(Kind: kind); |
1195 | return success(); |
1196 | } |
1197 | |
1198 | /// Attaches the attributes listed in the given array attribute to `llvmFunc`. |
1199 | /// Reports error to `loc` if any and returns immediately. Expects `attributes` |
1200 | /// to be an array attribute containing either string attributes, treated as |
1201 | /// value-less LLVM attributes, or array attributes containing two string |
1202 | /// attributes, with the first string being the name of the corresponding LLVM |
1203 | /// attribute and the second string beings its value. Note that even integer |
1204 | /// attributes are expected to have their values expressed as strings. |
1205 | static LogicalResult |
1206 | forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes, |
1207 | llvm::Function *llvmFunc) { |
1208 | if (!attributes) |
1209 | return success(); |
1210 | |
1211 | for (Attribute attr : *attributes) { |
1212 | if (auto stringAttr = dyn_cast<StringAttr>(attr)) { |
1213 | if (failed( |
1214 | checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) |
1215 | return failure(); |
1216 | continue; |
1217 | } |
1218 | |
1219 | auto arrayAttr = dyn_cast<ArrayAttr>(attr); |
1220 | if (!arrayAttr || arrayAttr.size() != 2) |
1221 | return emitError(loc) |
1222 | << "expected 'passthrough' to contain string or array attributes" ; |
1223 | |
1224 | auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]); |
1225 | auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]); |
1226 | if (!keyAttr || !valueAttr) |
1227 | return emitError(loc) |
1228 | << "expected arrays within 'passthrough' to contain two strings" ; |
1229 | |
1230 | if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(), |
1231 | valueAttr.getValue()))) |
1232 | return failure(); |
1233 | } |
1234 | return success(); |
1235 | } |
1236 | |
1237 | LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { |
1238 | // Clear the block, branch value mappings, they are only relevant within one |
1239 | // function. |
1240 | blockMapping.clear(); |
1241 | valueMapping.clear(); |
1242 | branchMapping.clear(); |
1243 | llvm::Function *llvmFunc = lookupFunction(name: func.getName()); |
1244 | |
1245 | // Add function arguments to the value remapping table. |
1246 | for (auto [mlirArg, llvmArg] : |
1247 | llvm::zip(func.getArguments(), llvmFunc->args())) |
1248 | mapValue(mlirArg, &llvmArg); |
1249 | |
1250 | // Check the personality and set it. |
1251 | if (func.getPersonality()) { |
1252 | llvm::Type *ty = llvm::PointerType::getUnqual(C&: llvmFunc->getContext()); |
1253 | if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(), |
1254 | func.getLoc(), *this)) |
1255 | llvmFunc->setPersonalityFn(pfunc); |
1256 | } |
1257 | |
1258 | if (std::optional<StringRef> section = func.getSection()) |
1259 | llvmFunc->setSection(*section); |
1260 | |
1261 | if (func.getArmStreaming()) |
1262 | llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_enabled" ); |
1263 | else if (func.getArmLocallyStreaming()) |
1264 | llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_body" ); |
1265 | else if (func.getArmStreamingCompatible()) |
1266 | llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_compatible" ); |
1267 | |
1268 | if (func.getArmNewZa()) |
1269 | llvmFunc->addFnAttr(Kind: "aarch64_new_za" ); |
1270 | else if (func.getArmInZa()) |
1271 | llvmFunc->addFnAttr(Kind: "aarch64_in_za" ); |
1272 | else if (func.getArmOutZa()) |
1273 | llvmFunc->addFnAttr(Kind: "aarch64_out_za" ); |
1274 | else if (func.getArmInoutZa()) |
1275 | llvmFunc->addFnAttr(Kind: "aarch64_inout_za" ); |
1276 | else if (func.getArmPreservesZa()) |
1277 | llvmFunc->addFnAttr(Kind: "aarch64_preserves_za" ); |
1278 | |
1279 | if (auto targetCpu = func.getTargetCpu()) |
1280 | llvmFunc->addFnAttr("target-cpu" , *targetCpu); |
1281 | |
1282 | if (auto targetFeatures = func.getTargetFeatures()) |
1283 | llvmFunc->addFnAttr("target-features" , targetFeatures->getFeaturesString()); |
1284 | |
1285 | if (auto attr = func.getVscaleRange()) |
1286 | llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs( |
1287 | Context&: getLLVMContext(), MinValue: attr->getMinRange().getInt(), |
1288 | MaxValue: attr->getMaxRange().getInt())); |
1289 | |
1290 | if (auto unsafeFpMath = func.getUnsafeFpMath()) |
1291 | llvmFunc->addFnAttr("unsafe-fp-math" , llvm::toStringRef(*unsafeFpMath)); |
1292 | |
1293 | if (auto noInfsFpMath = func.getNoInfsFpMath()) |
1294 | llvmFunc->addFnAttr("no-infs-fp-math" , llvm::toStringRef(*noInfsFpMath)); |
1295 | |
1296 | if (auto noNansFpMath = func.getNoNansFpMath()) |
1297 | llvmFunc->addFnAttr("no-nans-fp-math" , llvm::toStringRef(*noNansFpMath)); |
1298 | |
1299 | if (auto approxFuncFpMath = func.getApproxFuncFpMath()) |
1300 | llvmFunc->addFnAttr("approx-func-fp-math" , |
1301 | llvm::toStringRef(*approxFuncFpMath)); |
1302 | |
1303 | if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath()) |
1304 | llvmFunc->addFnAttr("no-signed-zeros-fp-math" , |
1305 | llvm::toStringRef(*noSignedZerosFpMath)); |
1306 | |
1307 | // Add function attribute frame-pointer, if found. |
1308 | if (FramePointerKindAttr attr = func.getFramePointerAttr()) |
1309 | llvmFunc->addFnAttr("frame-pointer" , |
1310 | LLVM::framePointerKind::stringifyFramePointerKind( |
1311 | (attr.getFramePointerKind()))); |
1312 | |
1313 | // First, create all blocks so we can jump to them. |
1314 | llvm::LLVMContext &llvmContext = llvmFunc->getContext(); |
1315 | for (auto &bb : func) { |
1316 | auto *llvmBB = llvm::BasicBlock::Create(llvmContext); |
1317 | llvmBB->insertInto(llvmFunc); |
1318 | mapBlock(&bb, llvmBB); |
1319 | } |
1320 | |
1321 | // Then, convert blocks one by one in topological order to ensure defs are |
1322 | // converted before uses. |
1323 | auto blocks = getTopologicallySortedBlocks(func.getBody()); |
1324 | for (Block *bb : blocks) { |
1325 | CapturingIRBuilder builder(llvmContext); |
1326 | if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder, |
1327 | /*recordInsertions=*/true))) |
1328 | return failure(); |
1329 | } |
1330 | |
1331 | // After all blocks have been traversed and values mapped, connect the PHI |
1332 | // nodes to the results of preceding blocks. |
1333 | detail::connectPHINodes(region&: func.getBody(), state: *this); |
1334 | |
1335 | // Finally, convert dialect attributes attached to the function. |
1336 | return convertDialectAttributes(op: func, instructions: {}); |
1337 | } |
1338 | |
1339 | LogicalResult ModuleTranslation::convertDialectAttributes( |
1340 | Operation *op, ArrayRef<llvm::Instruction *> instructions) { |
1341 | for (NamedAttribute attribute : op->getDialectAttrs()) |
1342 | if (failed(result: iface.amendOperation(op, instructions, attribute, moduleTranslation&: *this))) |
1343 | return failure(); |
1344 | return success(); |
1345 | } |
1346 | |
1347 | /// Converts the function attributes from LLVMFuncOp and attaches them to the |
1348 | /// llvm::Function. |
1349 | static void convertFunctionAttributes(LLVMFuncOp func, |
1350 | llvm::Function *llvmFunc) { |
1351 | if (!func.getMemory()) |
1352 | return; |
1353 | |
1354 | MemoryEffectsAttr memEffects = func.getMemoryAttr(); |
1355 | |
1356 | // Add memory effects incrementally. |
1357 | llvm::MemoryEffects newMemEffects = |
1358 | llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, |
1359 | convertModRefInfoToLLVM(memEffects.getArgMem())); |
1360 | newMemEffects |= llvm::MemoryEffects( |
1361 | llvm::MemoryEffects::Location::InaccessibleMem, |
1362 | convertModRefInfoToLLVM(memEffects.getInaccessibleMem())); |
1363 | newMemEffects |= |
1364 | llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, |
1365 | convertModRefInfoToLLVM(memEffects.getOther())); |
1366 | llvmFunc->setMemoryEffects(newMemEffects); |
1367 | } |
1368 | |
1369 | FailureOr<llvm::AttrBuilder> |
1370 | ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, |
1371 | DictionaryAttr paramAttrs) { |
1372 | llvm::AttrBuilder attrBuilder(llvmModule->getContext()); |
1373 | auto attrNameToKindMapping = getAttrNameToKindMapping(); |
1374 | |
1375 | for (auto namedAttr : paramAttrs) { |
1376 | auto it = attrNameToKindMapping.find(namedAttr.getName()); |
1377 | if (it != attrNameToKindMapping.end()) { |
1378 | llvm::Attribute::AttrKind llvmKind = it->second; |
1379 | |
1380 | llvm::TypeSwitch<Attribute>(namedAttr.getValue()) |
1381 | .Case<TypeAttr>([&](auto typeAttr) { |
1382 | attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue())); |
1383 | }) |
1384 | .Case<IntegerAttr>([&](auto intAttr) { |
1385 | attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); |
1386 | }) |
1387 | .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); }); |
1388 | } else if (namedAttr.getNameDialect()) { |
1389 | if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this))) |
1390 | return failure(); |
1391 | } |
1392 | } |
1393 | |
1394 | return attrBuilder; |
1395 | } |
1396 | |
1397 | LogicalResult ModuleTranslation::convertFunctionSignatures() { |
1398 | // Declare all functions first because there may be function calls that form a |
1399 | // call graph with cycles, or global initializers that reference functions. |
1400 | for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { |
1401 | llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( |
1402 | function.getName(), |
1403 | cast<llvm::FunctionType>(convertType(function.getFunctionType()))); |
1404 | llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee()); |
1405 | llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage())); |
1406 | llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv())); |
1407 | mapFunction(function.getName(), llvmFunc); |
1408 | addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc); |
1409 | |
1410 | // Convert function attributes. |
1411 | convertFunctionAttributes(function, llvmFunc); |
1412 | |
1413 | // Convert function_entry_count attribute to metadata. |
1414 | if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount()) |
1415 | llvmFunc->setEntryCount(entryCount.value()); |
1416 | |
1417 | // Convert result attributes. |
1418 | if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) { |
1419 | DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]); |
1420 | FailureOr<llvm::AttrBuilder> attrBuilder = |
1421 | convertParameterAttrs(function, -1, resultAttrs); |
1422 | if (failed(attrBuilder)) |
1423 | return failure(); |
1424 | llvmFunc->addRetAttrs(*attrBuilder); |
1425 | } |
1426 | |
1427 | // Convert argument attributes. |
1428 | for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) { |
1429 | if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) { |
1430 | FailureOr<llvm::AttrBuilder> attrBuilder = |
1431 | convertParameterAttrs(function, argIdx, argAttrs); |
1432 | if (failed(attrBuilder)) |
1433 | return failure(); |
1434 | llvmArg.addAttrs(*attrBuilder); |
1435 | } |
1436 | } |
1437 | |
1438 | // Forward the pass-through attributes to LLVM. |
1439 | if (failed(forwardPassthroughAttributes( |
1440 | function.getLoc(), function.getPassthrough(), llvmFunc))) |
1441 | return failure(); |
1442 | |
1443 | // Convert visibility attribute. |
1444 | llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_())); |
1445 | |
1446 | // Convert the comdat attribute. |
1447 | if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) { |
1448 | auto selectorOp = cast<ComdatSelectorOp>( |
1449 | SymbolTable::lookupNearestSymbolFrom(function, *comdat)); |
1450 | llvmFunc->setComdat(comdatMapping.lookup(selectorOp)); |
1451 | } |
1452 | |
1453 | if (auto gc = function.getGarbageCollector()) |
1454 | llvmFunc->setGC(gc->str()); |
1455 | |
1456 | if (auto unnamedAddr = function.getUnnamedAddr()) |
1457 | llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr)); |
1458 | |
1459 | if (auto alignment = function.getAlignment()) |
1460 | llvmFunc->setAlignment(llvm::MaybeAlign(*alignment)); |
1461 | |
1462 | // Translate the debug information for this function. |
1463 | debugTranslation->translate(function, *llvmFunc); |
1464 | } |
1465 | |
1466 | return success(); |
1467 | } |
1468 | |
1469 | LogicalResult ModuleTranslation::convertFunctions() { |
1470 | // Convert functions. |
1471 | for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { |
1472 | // Do not convert external functions, but do process dialect attributes |
1473 | // attached to them. |
1474 | if (function.isExternal()) { |
1475 | if (failed(convertDialectAttributes(function, {}))) |
1476 | return failure(); |
1477 | continue; |
1478 | } |
1479 | |
1480 | if (failed(convertOneFunction(function))) |
1481 | return failure(); |
1482 | } |
1483 | |
1484 | return success(); |
1485 | } |
1486 | |
1487 | LogicalResult ModuleTranslation::convertComdats() { |
1488 | for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) { |
1489 | for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) { |
1490 | llvm::Module *module = getLLVMModule(); |
1491 | if (module->getComdatSymbolTable().contains(selectorOp.getSymName())) |
1492 | return emitError(selectorOp.getLoc()) |
1493 | << "comdat selection symbols must be unique even in different " |
1494 | "comdat regions" ; |
1495 | llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName()); |
1496 | comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat())); |
1497 | comdatMapping.try_emplace(selectorOp, comdat); |
1498 | } |
1499 | } |
1500 | return success(); |
1501 | } |
1502 | |
1503 | void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op, |
1504 | llvm::Instruction *inst) { |
1505 | if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op)) |
1506 | inst->setMetadata(KindID: llvm::LLVMContext::MD_access_group, Node: node); |
1507 | } |
1508 | |
1509 | llvm::MDNode * |
1510 | ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) { |
1511 | auto [scopeIt, scopeInserted] = |
1512 | aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr); |
1513 | if (!scopeInserted) |
1514 | return scopeIt->second; |
1515 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
1516 | auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: std::nullopt); |
1517 | // Convert the domain metadata node if necessary. |
1518 | auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace( |
1519 | aliasScopeAttr.getDomain(), nullptr); |
1520 | if (insertedDomain) { |
1521 | llvm::SmallVector<llvm::Metadata *, 2> operands; |
1522 | // Placeholder for self-reference. |
1523 | operands.push_back(Elt: dummy.get()); |
1524 | if (StringAttr description = aliasScopeAttr.getDomain().getDescription()) |
1525 | operands.push_back(Elt: llvm::MDString::get(ctx, description)); |
1526 | domainIt->second = llvm::MDNode::get(ctx, operands); |
1527 | // Self-reference for uniqueness. |
1528 | domainIt->second->replaceOperandWith(0, domainIt->second); |
1529 | } |
1530 | // Convert the scope metadata node. |
1531 | assert(domainIt->second && "Scope's domain should already be valid" ); |
1532 | llvm::SmallVector<llvm::Metadata *, 3> operands; |
1533 | // Placeholder for self-reference. |
1534 | operands.push_back(Elt: dummy.get()); |
1535 | operands.push_back(Elt: domainIt->second); |
1536 | if (StringAttr description = aliasScopeAttr.getDescription()) |
1537 | operands.push_back(Elt: llvm::MDString::get(ctx, description)); |
1538 | scopeIt->second = llvm::MDNode::get(ctx, operands); |
1539 | // Self-reference for uniqueness. |
1540 | scopeIt->second->replaceOperandWith(0, scopeIt->second); |
1541 | return scopeIt->second; |
1542 | } |
1543 | |
1544 | llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes( |
1545 | ArrayRef<AliasScopeAttr> aliasScopeAttrs) { |
1546 | SmallVector<llvm::Metadata *> nodes; |
1547 | nodes.reserve(N: aliasScopeAttrs.size()); |
1548 | for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs) |
1549 | nodes.push_back(getOrCreateAliasScope(aliasScopeAttr)); |
1550 | return llvm::MDNode::get(Context&: getLLVMContext(), MDs: nodes); |
1551 | } |
1552 | |
1553 | void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op, |
1554 | llvm::Instruction *inst) { |
1555 | auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) { |
1556 | if (!aliasScopeAttrs || aliasScopeAttrs.empty()) |
1557 | return; |
1558 | llvm::MDNode *node = getOrCreateAliasScopes( |
1559 | aliasScopeAttrs: llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>())); |
1560 | inst->setMetadata(KindID: kind, Node: node); |
1561 | }; |
1562 | |
1563 | populateScopeMetadata(op.getAliasScopesOrNull(), |
1564 | llvm::LLVMContext::MD_alias_scope); |
1565 | populateScopeMetadata(op.getNoAliasScopesOrNull(), |
1566 | llvm::LLVMContext::MD_noalias); |
1567 | } |
1568 | |
1569 | llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const { |
1570 | return tbaaMetadataMapping.lookup(Val: tbaaAttr); |
1571 | } |
1572 | |
1573 | void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, |
1574 | llvm::Instruction *inst) { |
1575 | ArrayAttr tagRefs = op.getTBAATagsOrNull(); |
1576 | if (!tagRefs || tagRefs.empty()) |
1577 | return; |
1578 | |
1579 | // LLVM IR currently does not support attaching more than one TBAA access tag |
1580 | // to a memory accessing instruction. It may be useful to support this in |
1581 | // future, but for the time being just ignore the metadata if MLIR operation |
1582 | // has multiple access tags. |
1583 | if (tagRefs.size() > 1) { |
1584 | op.emitWarning() << "TBAA access tags were not translated, because LLVM " |
1585 | "IR only supports a single tag per instruction" ; |
1586 | return; |
1587 | } |
1588 | |
1589 | llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0])); |
1590 | inst->setMetadata(KindID: llvm::LLVMContext::MD_tbaa, Node: node); |
1591 | } |
1592 | |
1593 | void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { |
1594 | DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); |
1595 | if (!weightsAttr) |
1596 | return; |
1597 | |
1598 | llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op); |
1599 | assert(inst && "expected the operation to have a mapping to an instruction" ); |
1600 | SmallVector<uint32_t> weights(weightsAttr.asArrayRef()); |
1601 | inst->setMetadata( |
1602 | KindID: llvm::LLVMContext::MD_prof, |
1603 | Node: llvm::MDBuilder(getLLVMContext()).createBranchWeights(Weights: weights)); |
1604 | } |
1605 | |
1606 | LogicalResult ModuleTranslation::createTBAAMetadata() { |
1607 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
1608 | llvm::IntegerType *offsetTy = llvm::IntegerType::get(C&: ctx, NumBits: 64); |
1609 | |
1610 | // Walk the entire module and create all metadata nodes for the TBAA |
1611 | // attributes. The code below relies on two invariants of the |
1612 | // `AttrTypeWalker`: |
1613 | // 1. Attributes are visited in post-order: Since the attributes create a DAG, |
1614 | // this ensures that any lookups into `tbaaMetadataMapping` for child |
1615 | // attributes succeed. |
1616 | // 2. Attributes are only ever visited once: This way we don't leak any |
1617 | // LLVM metadata instances. |
1618 | AttrTypeWalker walker; |
1619 | walker.addWalk(callback: [&](TBAARootAttr root) { |
1620 | tbaaMetadataMapping.insert( |
1621 | {root, llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(ctx, root.getId()))}); |
1622 | }); |
1623 | |
1624 | walker.addWalk(callback: [&](TBAATypeDescriptorAttr descriptor) { |
1625 | SmallVector<llvm::Metadata *> operands; |
1626 | operands.push_back(Elt: llvm::MDString::get(ctx, descriptor.getId())); |
1627 | for (TBAAMemberAttr member : descriptor.getMembers()) { |
1628 | operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc())); |
1629 | operands.push_back(llvm::ConstantAsMetadata::get( |
1630 | llvm::ConstantInt::get(offsetTy, member.getOffset()))); |
1631 | } |
1632 | |
1633 | tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(Context&: ctx, MDs: operands)}); |
1634 | }); |
1635 | |
1636 | walker.addWalk(callback: [&](TBAATagAttr tag) { |
1637 | SmallVector<llvm::Metadata *> operands; |
1638 | |
1639 | operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getBaseType())); |
1640 | operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getAccessType())); |
1641 | |
1642 | operands.push_back(Elt: llvm::ConstantAsMetadata::get( |
1643 | C: llvm::ConstantInt::get(offsetTy, tag.getOffset()))); |
1644 | if (tag.getConstant()) |
1645 | operands.push_back( |
1646 | Elt: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get(Ty: offsetTy, V: 1))); |
1647 | |
1648 | tbaaMetadataMapping.insert({tag, llvm::MDNode::get(Context&: ctx, MDs: operands)}); |
1649 | }); |
1650 | |
1651 | mlirModule->walk(callback: [&](AliasAnalysisOpInterface analysisOpInterface) { |
1652 | if (auto attr = analysisOpInterface.getTBAATagsOrNull()) |
1653 | walker.walk(attr); |
1654 | }); |
1655 | |
1656 | return success(); |
1657 | } |
1658 | |
1659 | void ModuleTranslation::setLoopMetadata(Operation *op, |
1660 | llvm::Instruction *inst) { |
1661 | LoopAnnotationAttr attr = |
1662 | TypeSwitch<Operation *, LoopAnnotationAttr>(op) |
1663 | .Case<LLVM::BrOp, LLVM::CondBrOp>( |
1664 | [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); }); |
1665 | if (!attr) |
1666 | return; |
1667 | llvm::MDNode *loopMD = |
1668 | loopAnnotationTranslation->translateLoopAnnotation(attr, op); |
1669 | inst->setMetadata(KindID: llvm::LLVMContext::MD_loop, Node: loopMD); |
1670 | } |
1671 | |
1672 | llvm::Type *ModuleTranslation::convertType(Type type) { |
1673 | return typeTranslator.translateType(type); |
1674 | } |
1675 | |
1676 | /// A helper to look up remapped operands in the value remapping table. |
1677 | SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) { |
1678 | SmallVector<llvm::Value *> remapped; |
1679 | remapped.reserve(N: values.size()); |
1680 | for (Value v : values) |
1681 | remapped.push_back(Elt: lookupValue(value: v)); |
1682 | return remapped; |
1683 | } |
1684 | |
1685 | llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { |
1686 | if (!ompBuilder) { |
1687 | ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(args&: *llvmModule); |
1688 | ompBuilder->initialize(); |
1689 | |
1690 | // Flags represented as top-level OpenMP dialect attributes are set in |
1691 | // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set |
1692 | // the default configuration. |
1693 | ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig( |
1694 | /* IsTargetDevice = */ false, /* IsGPU = */ false, |
1695 | /* OpenMPOffloadMandatory = */ false, |
1696 | /* HasRequiresReverseOffload = */ false, |
1697 | /* HasRequiresUnifiedAddress = */ false, |
1698 | /* HasRequiresUnifiedSharedMemory = */ false, |
1699 | /* HasRequiresDynamicAllocators = */ false)); |
1700 | } |
1701 | return ompBuilder.get(); |
1702 | } |
1703 | |
1704 | llvm::DILocation *ModuleTranslation::translateLoc(Location loc, |
1705 | llvm::DILocalScope *scope) { |
1706 | return debugTranslation->translateLoc(loc, scope); |
1707 | } |
1708 | |
1709 | llvm::DIExpression * |
1710 | ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) { |
1711 | return debugTranslation->translateExpression(attr); |
1712 | } |
1713 | |
1714 | llvm::DIGlobalVariableExpression * |
1715 | ModuleTranslation::translateGlobalVariableExpression( |
1716 | LLVM::DIGlobalVariableExpressionAttr attr) { |
1717 | return debugTranslation->translateGlobalVariableExpression(attr); |
1718 | } |
1719 | |
1720 | llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) { |
1721 | return debugTranslation->translate(attr); |
1722 | } |
1723 | |
1724 | llvm::RoundingMode |
1725 | ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) { |
1726 | return convertRoundingModeToLLVM(rounding); |
1727 | } |
1728 | |
1729 | llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior( |
1730 | LLVM::FPExceptionBehavior exceptionBehavior) { |
1731 | return convertFPExceptionBehaviorToLLVM(exceptionBehavior); |
1732 | } |
1733 | |
1734 | llvm::NamedMDNode * |
1735 | ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) { |
1736 | return llvmModule->getOrInsertNamedMetadata(Name: name); |
1737 | } |
1738 | |
1739 | void ModuleTranslation::StackFrame::anchor() {} |
1740 | |
1741 | static std::unique_ptr<llvm::Module> |
1742 | prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, |
1743 | StringRef name) { |
1744 | m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>(); |
1745 | auto llvmModule = std::make_unique<llvm::Module>(args&: name, args&: llvmContext); |
1746 | if (auto dataLayoutAttr = |
1747 | m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) { |
1748 | llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue()); |
1749 | } else { |
1750 | FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout("" )); |
1751 | if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) { |
1752 | if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) { |
1753 | llvmDataLayout = |
1754 | translateDataLayout(spec, DataLayout(iface), m->getLoc()); |
1755 | } |
1756 | } else if (auto mod = dyn_cast<ModuleOp>(m)) { |
1757 | if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) { |
1758 | llvmDataLayout = |
1759 | translateDataLayout(spec, DataLayout(mod), m->getLoc()); |
1760 | } |
1761 | } |
1762 | if (failed(result: llvmDataLayout)) |
1763 | return nullptr; |
1764 | llvmModule->setDataLayout(*llvmDataLayout); |
1765 | } |
1766 | if (auto targetTripleAttr = |
1767 | m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) |
1768 | llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue()); |
1769 | |
1770 | return llvmModule; |
1771 | } |
1772 | |
1773 | std::unique_ptr<llvm::Module> |
1774 | mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, |
1775 | StringRef name) { |
1776 | if (!satisfiesLLVMModule(op: module)) { |
1777 | module->emitOpError(message: "can not be translated to an LLVMIR module" ); |
1778 | return nullptr; |
1779 | } |
1780 | |
1781 | std::unique_ptr<llvm::Module> llvmModule = |
1782 | prepareLLVMModule(m: module, llvmContext, name); |
1783 | if (!llvmModule) |
1784 | return nullptr; |
1785 | |
1786 | LLVM::ensureDistinctSuccessors(op: module); |
1787 | LLVM::legalizeDIExpressionsRecursively(op: module); |
1788 | |
1789 | ModuleTranslation translator(module, std::move(llvmModule)); |
1790 | llvm::IRBuilder<> llvmBuilder(llvmContext); |
1791 | |
1792 | // Convert module before functions and operations inside, so dialect |
1793 | // attributes can be used to change dialect-specific global configurations via |
1794 | // `amendOperation()`. These configurations can then influence the translation |
1795 | // of operations afterwards. |
1796 | if (failed(result: translator.convertOperation(op&: *module, builder&: llvmBuilder))) |
1797 | return nullptr; |
1798 | |
1799 | if (failed(result: translator.convertComdats())) |
1800 | return nullptr; |
1801 | if (failed(result: translator.convertFunctionSignatures())) |
1802 | return nullptr; |
1803 | if (failed(result: translator.convertGlobals())) |
1804 | return nullptr; |
1805 | if (failed(result: translator.createTBAAMetadata())) |
1806 | return nullptr; |
1807 | |
1808 | // Convert other top-level operations if possible. |
1809 | for (Operation &o : getModuleBody(module).getOperations()) { |
1810 | if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp, |
1811 | LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) && |
1812 | !o.hasTrait<OpTrait::IsTerminator>() && |
1813 | failed(translator.convertOperation(o, llvmBuilder))) { |
1814 | return nullptr; |
1815 | } |
1816 | } |
1817 | |
1818 | // Operations in function bodies with symbolic references must be converted |
1819 | // after the top-level operations they refer to are declared, so we do it |
1820 | // last. |
1821 | if (failed(result: translator.convertFunctions())) |
1822 | return nullptr; |
1823 | |
1824 | if (llvm::verifyModule(M: *translator.llvmModule, OS: &llvm::errs())) |
1825 | return nullptr; |
1826 | |
1827 | return std::move(translator.llvmModule); |
1828 | } |
1829 | |