1//===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the translation between an MLIR LLVM dialect module and
10// the corresponding LLVMIR module. It only handles core LLVM IR operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Target/LLVMIR/ModuleTranslation.h"
15
16#include "AttrKindDetail.h"
17#include "DebugTranslation.h"
18#include "LoopAnnotationTranslation.h"
19#include "mlir/Analysis/TopologicalSortUtils.h"
20#include "mlir/Dialect/DLTI/DLTI.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
23#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
24#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
25#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
26#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
27#include "mlir/IR/AttrTypeSubElements.h"
28#include "mlir/IR/Attributes.h"
29#include "mlir/IR/BuiltinOps.h"
30#include "mlir/IR/BuiltinTypes.h"
31#include "mlir/IR/DialectResourceBlobManager.h"
32#include "mlir/IR/RegionGraphTraits.h"
33#include "mlir/Support/LLVM.h"
34#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
35#include "mlir/Target/LLVMIR/TypeToLLVM.h"
36
37#include "llvm/ADT/PostOrderIterator.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/SetVector.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/TypeSwitch.h"
42#include "llvm/Analysis/TargetFolder.h"
43#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
44#include "llvm/IR/BasicBlock.h"
45#include "llvm/IR/CFG.h"
46#include "llvm/IR/Constants.h"
47#include "llvm/IR/DerivedTypes.h"
48#include "llvm/IR/IRBuilder.h"
49#include "llvm/IR/InlineAsm.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
51#include "llvm/IR/LLVMContext.h"
52#include "llvm/IR/MDBuilder.h"
53#include "llvm/IR/Module.h"
54#include "llvm/IR/Verifier.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Support/ErrorHandling.h"
57#include "llvm/Support/raw_ostream.h"
58#include "llvm/Transforms/Utils/BasicBlockUtils.h"
59#include "llvm/Transforms/Utils/Cloning.h"
60#include "llvm/Transforms/Utils/ModuleUtils.h"
61#include <numeric>
62#include <optional>
63
64#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
65
66using namespace mlir;
67using namespace mlir::LLVM;
68using namespace mlir::LLVM::detail;
69
70#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
71
72namespace {
73/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
74/// instructions that are created for future reference.
75///
76/// This is intended to be used with the `CollectionScope` RAII object:
77///
78/// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
79/// {
80/// InstructionCapturingInserter::CollectionScope scope(builder);
81/// // Call IRBuilder methods as usual.
82///
83/// // This will return a list of all instructions created by the builder,
84/// // in order of creation.
85/// builder.getInserter().getCapturedInstructions();
86/// }
87/// // This will return an empty list.
88/// builder.getInserter().getCapturedInstructions();
89///
90/// The capturing functionality is _disabled_ by default for performance
91/// consideration. It needs to be explicitly enabled, which is achieved by
92/// creating a `CollectionScope`.
93class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
94public:
95 /// Constructs the inserter.
96 InstructionCapturingInserter()
97 : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
98 if (LLVM_LIKELY(enabled))
99 capturedInstructions.push_back(instruction);
100 }) {}
101
102 /// Returns the list of LLVM IR instructions captured since the last cleanup.
103 ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
104 return capturedInstructions;
105 }
106
107 /// Clears the list of captured LLVM IR instructions.
108 void clearCapturedInstructions() { capturedInstructions.clear(); }
109
110 /// RAII object enabling the capture of created LLVM IR instructions.
111 class CollectionScope {
112 public:
113 /// Creates the scope for the given inserter.
114 CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
115
116 /// Ends the scope.
117 ~CollectionScope();
118
119 ArrayRef<llvm::Instruction *> getCapturedInstructions() {
120 if (!inserter)
121 return {};
122 return inserter->getCapturedInstructions();
123 }
124
125 private:
126 /// Back reference to the inserter.
127 InstructionCapturingInserter *inserter = nullptr;
128
129 /// List of instructions in the inserter prior to this scope.
130 SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
131
132 /// Whether the inserter was enabled prior to this scope.
133 bool wasEnabled;
134 };
135
136 /// Enable or disable the capturing mechanism.
137 void setEnabled(bool enabled = true) { this->enabled = enabled; }
138
139private:
140 /// List of captured instructions.
141 SmallVector<llvm::Instruction *> capturedInstructions;
142
143 /// Whether the collection is enabled.
144 bool enabled = false;
145};
146
147using CapturingIRBuilder =
148 llvm::IRBuilder<llvm::TargetFolder, InstructionCapturingInserter>;
149} // namespace
150
151InstructionCapturingInserter::CollectionScope::CollectionScope(
152 llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
153
154 if (!isBuilderCapturing)
155 return;
156
157 auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
158 inserter = &capturingIRBuilder.getInserter();
159 wasEnabled = inserter->enabled;
160 if (wasEnabled)
161 previouslyCollectedInstructions.swap(inserter->capturedInstructions);
162 inserter->setEnabled(true);
163}
164
165InstructionCapturingInserter::CollectionScope::~CollectionScope() {
166 if (!inserter)
167 return;
168
169 previouslyCollectedInstructions.swap(inserter->capturedInstructions);
170 // If collection was enabled (likely in another, surrounding scope), keep
171 // the instructions collected in this scope.
172 if (wasEnabled) {
173 llvm::append_range(inserter->capturedInstructions,
174 previouslyCollectedInstructions);
175 }
176 inserter->setEnabled(wasEnabled);
177}
178
179/// Translates the given data layout spec attribute to the LLVM IR data layout.
180/// Only integer, float, pointer and endianness entries are currently supported.
181static FailureOr<llvm::DataLayout>
182translateDataLayout(DataLayoutSpecInterface attribute,
183 const DataLayout &dataLayout,
184 std::optional<Location> loc = std::nullopt) {
185 if (!loc)
186 loc = UnknownLoc::get(attribute.getContext());
187
188 // Translate the endianness attribute.
189 std::string llvmDataLayout;
190 llvm::raw_string_ostream layoutStream(llvmDataLayout);
191 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
192 auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
193 if (!key)
194 continue;
195 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
196 auto value = cast<StringAttr>(entry.getValue());
197 bool isLittleEndian =
198 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
199 layoutStream << "-" << (isLittleEndian ? "e" : "E");
200 continue;
201 }
202 if (key.getValue() == DLTIDialect::kDataLayoutManglingModeKey) {
203 auto value = cast<StringAttr>(entry.getValue());
204 layoutStream << "-m:" << value.getValue();
205 continue;
206 }
207 if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) {
208 auto value = cast<IntegerAttr>(entry.getValue());
209 uint64_t space = value.getValue().getZExtValue();
210 // Skip the default address space.
211 if (space == 0)
212 continue;
213 layoutStream << "-P" << space;
214 continue;
215 }
216 if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) {
217 auto value = cast<IntegerAttr>(entry.getValue());
218 uint64_t space = value.getValue().getZExtValue();
219 // Skip the default address space.
220 if (space == 0)
221 continue;
222 layoutStream << "-G" << space;
223 continue;
224 }
225 if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
226 auto value = cast<IntegerAttr>(entry.getValue());
227 uint64_t space = value.getValue().getZExtValue();
228 // Skip the default address space.
229 if (space == 0)
230 continue;
231 layoutStream << "-A" << space;
232 continue;
233 }
234 if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
235 auto value = cast<IntegerAttr>(entry.getValue());
236 uint64_t alignment = value.getValue().getZExtValue();
237 // Skip the default stack alignment.
238 if (alignment == 0)
239 continue;
240 layoutStream << "-S" << alignment;
241 continue;
242 }
243 if (key.getValue() == DLTIDialect::kDataLayoutFunctionPointerAlignmentKey) {
244 auto value = cast<FunctionPointerAlignmentAttr>(entry.getValue());
245 uint64_t alignment = value.getAlignment();
246 // Skip the default function pointer alignment.
247 if (alignment == 0)
248 continue;
249 layoutStream << "-F" << (value.getFunctionDependent() ? "n" : "i")
250 << alignment;
251 continue;
252 }
253 if (key.getValue() == DLTIDialect::kDataLayoutLegalIntWidthsKey) {
254 layoutStream << "-n";
255 llvm::interleave(
256 cast<DenseI32ArrayAttr>(entry.getValue()).asArrayRef(), layoutStream,
257 [&](int32_t val) { layoutStream << val; }, ":");
258 continue;
259 }
260 emitError(*loc) << "unsupported data layout key " << key;
261 return failure();
262 }
263
264 // Go through the list of entries to check which types are explicitly
265 // specified in entries. Where possible, data layout queries are used instead
266 // of directly inspecting the entries.
267 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
268 auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
269 if (!type)
270 continue;
271 // Data layout for the index type is irrelevant at this point.
272 if (isa<IndexType>(type))
273 continue;
274 layoutStream << "-";
275 LogicalResult result =
276 llvm::TypeSwitch<Type, LogicalResult>(type)
277 .Case<IntegerType, Float16Type, Float32Type, Float64Type,
278 Float80Type, Float128Type>([&](Type type) -> LogicalResult {
279 if (auto intType = dyn_cast<IntegerType>(type)) {
280 if (intType.getSignedness() != IntegerType::Signless)
281 return emitError(*loc)
282 << "unsupported data layout for non-signless integer "
283 << intType;
284 layoutStream << "i";
285 } else {
286 layoutStream << "f";
287 }
288 uint64_t size = dataLayout.getTypeSizeInBits(type);
289 uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u;
290 uint64_t preferred =
291 dataLayout.getTypePreferredAlignment(type) * 8u;
292 layoutStream << size << ":" << abi;
293 if (abi != preferred)
294 layoutStream << ":" << preferred;
295 return success();
296 })
297 .Case([&](LLVMPointerType type) {
298 layoutStream << "p" << type.getAddressSpace() << ":";
299 uint64_t size = dataLayout.getTypeSizeInBits(type);
300 uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u;
301 uint64_t preferred =
302 dataLayout.getTypePreferredAlignment(type) * 8u;
303 uint64_t index = *dataLayout.getTypeIndexBitwidth(type);
304 layoutStream << size << ":" << abi << ":" << preferred << ":"
305 << index;
306 return success();
307 })
308 .Default([loc](Type type) {
309 return emitError(*loc)
310 << "unsupported type in data layout: " << type;
311 });
312 if (failed(result))
313 return failure();
314 }
315 StringRef layoutSpec(llvmDataLayout);
316 layoutSpec.consume_front(Prefix: "-");
317
318 return llvm::DataLayout(layoutSpec);
319}
320
321/// Builds a constant of a sequential LLVM type `type`, potentially containing
322/// other sequential types recursively, from the individual constant values
323/// provided in `constants`. `shape` contains the number of elements in nested
324/// sequential types. Reports errors at `loc` and returns nullptr on error.
325static llvm::Constant *
326buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
327 ArrayRef<int64_t> shape, llvm::Type *type,
328 Location loc) {
329 if (shape.empty()) {
330 llvm::Constant *result = constants.front();
331 constants = constants.drop_front();
332 return result;
333 }
334
335 llvm::Type *elementType;
336 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) {
337 elementType = arrayTy->getElementType();
338 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) {
339 elementType = vectorTy->getElementType();
340 } else {
341 emitError(loc) << "expected sequential LLVM types wrapping a scalar";
342 return nullptr;
343 }
344
345 SmallVector<llvm::Constant *, 8> nested;
346 nested.reserve(N: shape.front());
347 for (int64_t i = 0; i < shape.front(); ++i) {
348 nested.push_back(Elt: buildSequentialConstant(constants, shape: shape.drop_front(),
349 type: elementType, loc));
350 if (!nested.back())
351 return nullptr;
352 }
353
354 if (shape.size() == 1 && type->isVectorTy())
355 return llvm::ConstantVector::get(V: nested);
356 return llvm::ConstantArray::get(
357 T: llvm::ArrayType::get(ElementType: elementType, NumElements: shape.front()), V: nested);
358}
359
360/// Returns the first non-sequential type nested in sequential types.
361static llvm::Type *getInnermostElementType(llvm::Type *type) {
362 do {
363 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) {
364 type = arrayTy->getElementType();
365 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) {
366 type = vectorTy->getElementType();
367 } else {
368 return type;
369 }
370 } while (true);
371}
372
373/// Convert a dense elements attribute to an LLVM IR constant using its raw data
374/// storage if possible. This supports elements attributes of tensor or vector
375/// type and avoids constructing separate objects for individual values of the
376/// innermost dimension. Constants for other dimensions are still constructed
377/// recursively. Returns null if constructing from raw data is not supported for
378/// this type, e.g., element type is not a power-of-two-sized primitive. Reports
379/// other errors at `loc`.
380static llvm::Constant *
381convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
382 llvm::Type *llvmType,
383 const ModuleTranslation &moduleTranslation) {
384 if (!denseElementsAttr)
385 return nullptr;
386
387 llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType);
388 if (!llvm::ConstantDataSequential::isElementTypeCompatible(Ty: innermostLLVMType))
389 return nullptr;
390
391 ShapedType type = denseElementsAttr.getType();
392 if (type.getNumElements() == 0)
393 return nullptr;
394
395 // Check that the raw data size matches what is expected for the scalar size.
396 // TODO: in theory, we could repack the data here to keep constructing from
397 // raw data.
398 // TODO: we may also need to consider endianness when cross-compiling to an
399 // architecture where it is different.
400 int64_t elementByteSize = denseElementsAttr.getRawData().size() /
401 denseElementsAttr.getNumElements();
402 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
403 return nullptr;
404
405 // Compute the shape of all dimensions but the innermost. Note that the
406 // innermost dimension may be that of the vector element type.
407 bool hasVectorElementType = isa<VectorType>(type.getElementType());
408 int64_t numAggregates =
409 denseElementsAttr.getNumElements() /
410 (hasVectorElementType ? 1
411 : denseElementsAttr.getType().getShape().back());
412 ArrayRef<int64_t> outerShape = type.getShape();
413 if (!hasVectorElementType)
414 outerShape = outerShape.drop_back();
415
416 // Handle the case of vector splat, LLVM has special support for it.
417 if (denseElementsAttr.isSplat() &&
418 (isa<VectorType>(type) || hasVectorElementType)) {
419 llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
420 llvmType: innermostLLVMType, attr: denseElementsAttr.getSplatValue<Attribute>(), loc,
421 moduleTranslation);
422 llvm::Constant *splatVector =
423 llvm::ConstantDataVector::getSplat(NumElts: 0, Elt: splatValue);
424 SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
425 ArrayRef<llvm::Constant *> constantsRef = constants;
426 return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc);
427 }
428 if (denseElementsAttr.isSplat())
429 return nullptr;
430
431 // In case of non-splat, create a constructor for the innermost constant from
432 // a piece of raw data.
433 std::function<llvm::Constant *(StringRef)> buildCstData;
434 if (isa<TensorType>(type)) {
435 auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
436 if (vectorElementType && vectorElementType.getRank() == 1) {
437 buildCstData = [&](StringRef data) {
438 return llvm::ConstantDataVector::getRaw(
439 data, vectorElementType.getShape().back(), innermostLLVMType);
440 };
441 } else if (!vectorElementType) {
442 buildCstData = [&](StringRef data) {
443 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
444 innermostLLVMType);
445 };
446 }
447 } else if (isa<VectorType>(type)) {
448 buildCstData = [&](StringRef data) {
449 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
450 innermostLLVMType);
451 };
452 }
453 if (!buildCstData)
454 return nullptr;
455
456 // Create innermost constants and defer to the default constant creation
457 // mechanism for other dimensions.
458 SmallVector<llvm::Constant *> constants;
459 int64_t aggregateSize = denseElementsAttr.getType().getShape().back() *
460 (innermostLLVMType->getScalarSizeInBits() / 8);
461 constants.reserve(N: numAggregates);
462 for (unsigned i = 0; i < numAggregates; ++i) {
463 StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
464 aggregateSize);
465 constants.push_back(Elt: buildCstData(data));
466 }
467
468 ArrayRef<llvm::Constant *> constantsRef = constants;
469 return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc);
470}
471
472/// Convert a dense resource elements attribute to an LLVM IR constant using its
473/// raw data storage if possible. This supports elements attributes of tensor or
474/// vector type and avoids constructing separate objects for individual values
475/// of the innermost dimension. Constants for other dimensions are still
476/// constructed recursively. Returns nullptr on failure and emits errors at
477/// `loc`.
478static llvm::Constant *convertDenseResourceElementsAttr(
479 Location loc, DenseResourceElementsAttr denseResourceAttr,
480 llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
481 assert(denseResourceAttr && "expected non-null attribute");
482
483 llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType);
484 if (!llvm::ConstantDataSequential::isElementTypeCompatible(
485 Ty: innermostLLVMType)) {
486 emitError(loc, message: "no known conversion for innermost element type");
487 return nullptr;
488 }
489
490 ShapedType type = denseResourceAttr.getType();
491 assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
492
493 AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
494 if (!blob) {
495 emitError(loc, message: "resource does not exist");
496 return nullptr;
497 }
498
499 ArrayRef<char> rawData = blob->getData();
500
501 // Check that the raw data size matches what is expected for the scalar size.
502 // TODO: in theory, we could repack the data here to keep constructing from
503 // raw data.
504 // TODO: we may also need to consider endianness when cross-compiling to an
505 // architecture where it is different.
506 int64_t numElements = denseResourceAttr.getType().getNumElements();
507 int64_t elementByteSize = rawData.size() / numElements;
508 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
509 emitError(loc, message: "raw data size does not match element type size");
510 return nullptr;
511 }
512
513 // Compute the shape of all dimensions but the innermost. Note that the
514 // innermost dimension may be that of the vector element type.
515 bool hasVectorElementType = isa<VectorType>(type.getElementType());
516 int64_t numAggregates =
517 numElements / (hasVectorElementType
518 ? 1
519 : denseResourceAttr.getType().getShape().back());
520 ArrayRef<int64_t> outerShape = type.getShape();
521 if (!hasVectorElementType)
522 outerShape = outerShape.drop_back();
523
524 // Create a constructor for the innermost constant from a piece of raw data.
525 std::function<llvm::Constant *(StringRef)> buildCstData;
526 if (isa<TensorType>(type)) {
527 auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
528 if (vectorElementType && vectorElementType.getRank() == 1) {
529 buildCstData = [&](StringRef data) {
530 return llvm::ConstantDataVector::getRaw(
531 data, vectorElementType.getShape().back(), innermostLLVMType);
532 };
533 } else if (!vectorElementType) {
534 buildCstData = [&](StringRef data) {
535 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
536 innermostLLVMType);
537 };
538 }
539 } else if (isa<VectorType>(type)) {
540 buildCstData = [&](StringRef data) {
541 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
542 innermostLLVMType);
543 };
544 }
545 if (!buildCstData) {
546 emitError(loc, message: "unsupported dense_resource type");
547 return nullptr;
548 }
549
550 // Create innermost constants and defer to the default constant creation
551 // mechanism for other dimensions.
552 SmallVector<llvm::Constant *> constants;
553 int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
554 (innermostLLVMType->getScalarSizeInBits() / 8);
555 constants.reserve(N: numAggregates);
556 for (unsigned i = 0; i < numAggregates; ++i) {
557 StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
558 constants.push_back(Elt: buildCstData(data));
559 }
560
561 ArrayRef<llvm::Constant *> constantsRef = constants;
562 return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc);
563}
564
565/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
566/// This currently supports integer, floating point, splat and dense element
567/// attributes and combinations thereof. Also, an array attribute with two
568/// elements is supported to represent a complex constant. In case of error,
569/// report it to `loc` and return nullptr.
570llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
571 llvm::Type *llvmType, Attribute attr, Location loc,
572 const ModuleTranslation &moduleTranslation) {
573 if (!attr || isa<UndefAttr>(attr))
574 return llvm::UndefValue::get(T: llvmType);
575 if (isa<ZeroAttr>(attr))
576 return llvm::Constant::getNullValue(Ty: llvmType);
577 if (auto *structType = dyn_cast<::llvm::StructType>(Val: llvmType)) {
578 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
579 if (!arrayAttr) {
580 emitError(loc, message: "expected an array attribute for a struct constant");
581 return nullptr;
582 }
583 SmallVector<llvm::Constant *> structElements;
584 structElements.reserve(N: structType->getNumElements());
585 for (auto [elemType, elemAttr] :
586 zip_equal(structType->elements(), arrayAttr)) {
587 llvm::Constant *element =
588 getLLVMConstant(elemType, elemAttr, loc, moduleTranslation);
589 if (!element)
590 return nullptr;
591 structElements.push_back(element);
592 }
593 return llvm::ConstantStruct::get(T: structType, V: structElements);
594 }
595 // For integer types, we allow a mismatch in sizes as the index type in
596 // MLIR might have a different size than the index type in the LLVM module.
597 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
598 return llvm::ConstantInt::get(
599 llvmType,
600 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
601 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
602 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
603 // Special case for 8-bit floats, which are represented by integers due to
604 // the lack of native fp8 types in LLVM at the moment. Additionally, handle
605 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
606 // to i16.
607 unsigned floatWidth = APFloat::getSizeInBits(Sem: sem);
608 if (llvmType->isIntegerTy(Bitwidth: floatWidth))
609 return llvm::ConstantInt::get(llvmType,
610 floatAttr.getValue().bitcastToAPInt());
611 if (llvmType !=
612 llvm::Type::getFloatingPointTy(C&: llvmType->getContext(),
613 S: floatAttr.getValue().getSemantics())) {
614 emitError(loc, message: "FloatAttr does not match expected type of the constant");
615 return nullptr;
616 }
617 return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
618 }
619 if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
620 return llvm::ConstantExpr::getBitCast(
621 C: moduleTranslation.lookupFunction(name: funcAttr.getValue()), Ty: llvmType);
622 if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr)) {
623 llvm::Type *elementType;
624 uint64_t numElements;
625 bool isScalable = false;
626 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) {
627 elementType = arrayTy->getElementType();
628 numElements = arrayTy->getNumElements();
629 } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(Val: llvmType)) {
630 elementType = fVectorTy->getElementType();
631 numElements = fVectorTy->getNumElements();
632 } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(Val: llvmType)) {
633 elementType = sVectorTy->getElementType();
634 numElements = sVectorTy->getMinNumElements();
635 isScalable = true;
636 } else {
637 llvm_unreachable("unrecognized constant vector type");
638 }
639 // Splat value is a scalar. Extract it only if the element type is not
640 // another sequence type. The recursion terminates because each step removes
641 // one outer sequential type.
642 bool elementTypeSequential =
643 isa<llvm::ArrayType, llvm::VectorType>(Val: elementType);
644 llvm::Constant *child = getLLVMConstant(
645 llvmType: elementType,
646 attr: elementTypeSequential ? splatAttr
647 : splatAttr.getSplatValue<Attribute>(),
648 loc, moduleTranslation);
649 if (!child)
650 return nullptr;
651 if (llvmType->isVectorTy())
652 return llvm::ConstantVector::getSplat(
653 EC: llvm::ElementCount::get(MinVal: numElements, /*Scalable=*/isScalable), Elt: child);
654 if (llvmType->isArrayTy()) {
655 auto *arrayType = llvm::ArrayType::get(ElementType: elementType, NumElements: numElements);
656 if (child->isZeroValue()) {
657 return llvm::ConstantAggregateZero::get(Ty: arrayType);
658 } else {
659 if (llvm::ConstantDataSequential::isElementTypeCompatible(
660 Ty: elementType)) {
661 // TODO: Handle all compatible types. This code only handles integer.
662 if (isa<llvm::IntegerType>(Val: elementType)) {
663 if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(Val: child)) {
664 if (ci->getBitWidth() == 8) {
665 SmallVector<int8_t> constants(numElements, ci->getZExtValue());
666 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
667 Elts&: constants);
668 }
669 if (ci->getBitWidth() == 16) {
670 SmallVector<int16_t> constants(numElements, ci->getZExtValue());
671 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
672 Elts&: constants);
673 }
674 if (ci->getBitWidth() == 32) {
675 SmallVector<int32_t> constants(numElements, ci->getZExtValue());
676 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
677 Elts&: constants);
678 }
679 if (ci->getBitWidth() == 64) {
680 SmallVector<int64_t> constants(numElements, ci->getZExtValue());
681 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
682 Elts&: constants);
683 }
684 }
685 }
686 }
687 // std::vector is used here to accomodate large number of elements that
688 // exceed SmallVector capacity.
689 std::vector<llvm::Constant *> constants(numElements, child);
690 return llvm::ConstantArray::get(T: arrayType, V: constants);
691 }
692 }
693 }
694
695 // Try using raw elements data if possible.
696 if (llvm::Constant *result =
697 convertDenseElementsAttr(loc, denseElementsAttr: dyn_cast<DenseElementsAttr>(Val&: attr),
698 llvmType, moduleTranslation)) {
699 return result;
700 }
701
702 if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
703 return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
704 moduleTranslation);
705 }
706
707 // Fall back to element-by-element construction otherwise.
708 if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
709 assert(elementsAttr.getShapedType().hasStaticShape());
710 assert(!elementsAttr.getShapedType().getShape().empty() &&
711 "unexpected empty elements attribute shape");
712
713 SmallVector<llvm::Constant *, 8> constants;
714 constants.reserve(N: elementsAttr.getNumElements());
715 llvm::Type *innermostType = getInnermostElementType(type: llvmType);
716 for (auto n : elementsAttr.getValues<Attribute>()) {
717 constants.push_back(
718 getLLVMConstant(innermostType, n, loc, moduleTranslation));
719 if (!constants.back())
720 return nullptr;
721 }
722 ArrayRef<llvm::Constant *> constantsRef = constants;
723 llvm::Constant *result = buildSequentialConstant(
724 constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
725 assert(constantsRef.empty() && "did not consume all elemental constants");
726 return result;
727 }
728
729 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
730 return llvm::ConstantDataArray::get(
731 Context&: moduleTranslation.getLLVMContext(),
732 Elts: ArrayRef<char>{stringAttr.getValue().data(),
733 stringAttr.getValue().size()});
734 }
735
736 // Handle arrays of structs that cannot be represented as DenseElementsAttr
737 // in MLIR.
738 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
739 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) {
740 llvm::Type *elementType = arrayTy->getElementType();
741 Attribute previousElementAttr;
742 llvm::Constant *elementCst = nullptr;
743 SmallVector<llvm::Constant *> constants;
744 constants.reserve(N: arrayTy->getNumElements());
745 for (Attribute elementAttr : arrayAttr) {
746 // Arrays with a single value or with repeating values are quite common.
747 // Short-circuit the translation when the element value is the same as
748 // the previous one.
749 if (!previousElementAttr || previousElementAttr != elementAttr) {
750 previousElementAttr = elementAttr;
751 elementCst =
752 getLLVMConstant(elementType, elementAttr, loc, moduleTranslation);
753 if (!elementCst)
754 return nullptr;
755 }
756 constants.push_back(elementCst);
757 }
758 return llvm::ConstantArray::get(T: arrayTy, V: constants);
759 }
760 }
761
762 emitError(loc, message: "unsupported constant value");
763 return nullptr;
764}
765
766ModuleTranslation::ModuleTranslation(Operation *module,
767 std::unique_ptr<llvm::Module> llvmModule)
768 : mlirModule(module), llvmModule(std::move(llvmModule)),
769 debugTranslation(
770 std::make_unique<DebugTranslation>(args&: module, args&: *this->llvmModule)),
771 loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
772 args&: *this, args&: *this->llvmModule)),
773 typeTranslator(this->llvmModule->getContext()),
774 iface(module->getContext()) {
775 assert(satisfiesLLVMModule(mlirModule) &&
776 "mlirModule should honor LLVM's module semantics.");
777}
778
779ModuleTranslation::~ModuleTranslation() {
780 if (ompBuilder)
781 ompBuilder->finalize();
782}
783
784void ModuleTranslation::forgetMapping(Region &region) {
785 SmallVector<Region *> toProcess;
786 toProcess.push_back(Elt: &region);
787 while (!toProcess.empty()) {
788 Region *current = toProcess.pop_back_val();
789 for (Block &block : *current) {
790 blockMapping.erase(Val: &block);
791 for (Value arg : block.getArguments())
792 valueMapping.erase(Val: arg);
793 for (Operation &op : block) {
794 for (Value value : op.getResults())
795 valueMapping.erase(Val: value);
796 if (op.hasSuccessors())
797 branchMapping.erase(Val: &op);
798 if (isa<LLVM::GlobalOp>(op))
799 globalsMapping.erase(Val: &op);
800 if (isa<LLVM::AliasOp>(op))
801 aliasesMapping.erase(Val: &op);
802 if (isa<LLVM::CallOp>(op))
803 callMapping.erase(Val: &op);
804 llvm::append_range(
805 C&: toProcess,
806 R: llvm::map_range(C: op.getRegions(), F: [](Region &r) { return &r; }));
807 }
808 }
809 }
810}
811
812/// Get the SSA value passed to the current block from the terminator operation
813/// of its predecessor.
814static Value getPHISourceValue(Block *current, Block *pred,
815 unsigned numArguments, unsigned index) {
816 Operation &terminator = *pred->getTerminator();
817 if (isa<LLVM::BrOp>(terminator))
818 return terminator.getOperand(idx: index);
819
820#ifndef NDEBUG
821 llvm::SmallPtrSet<Block *, 4> seenSuccessors;
822 for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
823 Block *successor = terminator.getSuccessor(index: i);
824 auto branch = cast<BranchOpInterface>(terminator);
825 SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
826 assert(
827 (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
828 "successors with arguments in LLVM branches must be different blocks");
829 seenSuccessors.insert(Ptr: successor);
830 }
831#endif
832
833 // For instructions that branch based on a condition value, we need to take
834 // the operands for the branch that was taken.
835 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
836 // For conditional branches, we take the operands from either the "true" or
837 // the "false" branch.
838 return condBranchOp.getSuccessor(0) == current
839 ? condBranchOp.getTrueDestOperands()[index]
840 : condBranchOp.getFalseDestOperands()[index];
841 }
842
843 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
844 // For switches, we take the operands from either the default case, or from
845 // the case branch that was taken.
846 if (switchOp.getDefaultDestination() == current)
847 return switchOp.getDefaultOperands()[index];
848 for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations()))
849 if (i.value() == current)
850 return switchOp.getCaseOperands(i.index())[index];
851 }
852
853 if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(terminator)) {
854 // For indirect branches we take operands for each successor.
855 for (const auto &i : llvm::enumerate(indBrOp->getSuccessors())) {
856 if (indBrOp->getSuccessor(i.index()) == current)
857 return indBrOp.getSuccessorOperands(i.index())[index];
858 }
859 }
860
861 if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
862 return invokeOp.getNormalDest() == current
863 ? invokeOp.getNormalDestOperands()[index]
864 : invokeOp.getUnwindDestOperands()[index];
865 }
866
867 llvm_unreachable(
868 "only branch, switch or invoke operations can be terminators "
869 "of a block that has successors");
870}
871
872/// Connect the PHI nodes to the results of preceding blocks.
873void mlir::LLVM::detail::connectPHINodes(Region &region,
874 const ModuleTranslation &state) {
875 // Skip the first block, it cannot be branched to and its arguments correspond
876 // to the arguments of the LLVM function.
877 for (Block &bb : llvm::drop_begin(RangeOrContainer&: region)) {
878 llvm::BasicBlock *llvmBB = state.lookupBlock(block: &bb);
879 auto phis = llvmBB->phis();
880 auto numArguments = bb.getNumArguments();
881 assert(numArguments == std::distance(phis.begin(), phis.end()));
882 for (auto [index, phiNode] : llvm::enumerate(First&: phis)) {
883 for (auto *pred : bb.getPredecessors()) {
884 // Find the LLVM IR block that contains the converted terminator
885 // instruction and use it in the PHI node. Note that this block is not
886 // necessarily the same as state.lookupBlock(pred), some operations
887 // (in particular, OpenMP operations using OpenMPIRBuilder) may have
888 // split the blocks.
889 llvm::Instruction *terminator =
890 state.lookupBranch(op: pred->getTerminator());
891 assert(terminator && "missing the mapping for a terminator");
892 phiNode.addIncoming(V: state.lookupValue(value: getPHISourceValue(
893 current: &bb, pred, numArguments, index)),
894 BB: terminator->getParent());
895 }
896 }
897 }
898}
899
900llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
901 llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
902 ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
903 llvm::Module *module = builder.GetInsertBlock()->getModule();
904 llvm::Function *fn =
905 llvm::Intrinsic::getOrInsertDeclaration(M: module, id: intrinsic, Tys: tys);
906 return builder.CreateCall(Callee: fn, Args: args);
907}
908
909llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
910 llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
911 Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
912 ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
913 ArrayRef<unsigned> immArgPositions,
914 ArrayRef<StringLiteral> immArgAttrNames) {
915 assert(immArgPositions.size() == immArgAttrNames.size() &&
916 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
917 "length");
918
919 SmallVector<llvm::OperandBundleDef> opBundles;
920 size_t numOpBundleOperands = 0;
921 auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
922 intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
923 auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
924 intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
925
926 if (opBundleSizesAttr && opBundleTagsAttr) {
927 ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
928 assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
929 "operand bundles and tags do not match");
930
931 numOpBundleOperands =
932 std::accumulate(first: opBundleSizes.begin(), last: opBundleSizes.end(), init: size_t(0));
933 assert(numOpBundleOperands <= intrOp->getNumOperands() &&
934 "operand bundle operands is more than the number of operands");
935
936 ValueRange operands = intrOp->getOperands().take_back(n: numOpBundleOperands);
937 size_t nextOperandIdx = 0;
938 opBundles.reserve(N: opBundleSizesAttr.size());
939
940 for (auto [opBundleTagAttr, bundleSize] :
941 llvm::zip(opBundleTagsAttr, opBundleSizes)) {
942 auto bundleTag = cast<StringAttr>(opBundleTagAttr).str();
943 auto bundleOperands = moduleTranslation.lookupValues(
944 operands.slice(nextOperandIdx, bundleSize));
945 opBundles.emplace_back(std::move(bundleTag), std::move(bundleOperands));
946 nextOperandIdx += bundleSize;
947 }
948 }
949
950 // Map operands and attributes to LLVM values.
951 auto opOperands = intrOp->getOperands().drop_back(n: numOpBundleOperands);
952 auto operands = moduleTranslation.lookupValues(values: opOperands);
953 SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
954 for (auto [immArgPos, immArgName] :
955 llvm::zip(t&: immArgPositions, u&: immArgAttrNames)) {
956 auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
957 assert(attr.getType().isIntOrFloat() && "expected int or float immarg");
958 auto *type = moduleTranslation.convertType(type: attr.getType());
959 args[immArgPos] = LLVM::detail::getLLVMConstant(
960 llvmType: type, attr: attr, loc: intrOp->getLoc(), moduleTranslation);
961 }
962 unsigned opArg = 0;
963 for (auto &arg : args) {
964 if (!arg)
965 arg = operands[opArg++];
966 }
967
968 // Resolve overloaded intrinsic declaration.
969 SmallVector<llvm::Type *> overloadedTypes;
970 for (unsigned overloadedResultIdx : overloadedResults) {
971 if (numResults > 1) {
972 // More than one result is mapped to an LLVM struct.
973 overloadedTypes.push_back(moduleTranslation.convertType(
974 llvm::cast<LLVM::LLVMStructType>(intrOp->getResult(0).getType())
975 .getBody()[overloadedResultIdx]));
976 } else {
977 overloadedTypes.push_back(
978 Elt: moduleTranslation.convertType(type: intrOp->getResult(idx: 0).getType()));
979 }
980 }
981 for (unsigned overloadedOperandIdx : overloadedOperands)
982 overloadedTypes.push_back(Elt: args[overloadedOperandIdx]->getType());
983 llvm::Module *module = builder.GetInsertBlock()->getModule();
984 llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
985 M: module, id: intrinsic, Tys: overloadedTypes);
986
987 return builder.CreateCall(Callee: llvmIntr, Args: args, OpBundles: opBundles);
988}
989
990/// Given a single MLIR operation, create the corresponding LLVM IR operation
991/// using the `builder`.
992LogicalResult ModuleTranslation::convertOperation(Operation &op,
993 llvm::IRBuilderBase &builder,
994 bool recordInsertions) {
995 const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(obj: &op);
996 if (!opIface)
997 return op.emitError(message: "cannot be converted to LLVM IR: missing "
998 "`LLVMTranslationDialectInterface` registration for "
999 "dialect for op: ")
1000 << op.getName();
1001
1002 InstructionCapturingInserter::CollectionScope scope(builder,
1003 recordInsertions);
1004 if (failed(Result: opIface->convertOperation(op: &op, builder, moduleTranslation&: *this)))
1005 return op.emitError(message: "LLVM Translation failed for operation: ")
1006 << op.getName();
1007
1008 return convertDialectAttributes(op: &op, instructions: scope.getCapturedInstructions());
1009}
1010
1011/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
1012/// to define values corresponding to the MLIR block arguments. These nodes
1013/// are not connected to the source basic blocks, which may not exist yet. Uses
1014/// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
1015/// been created for `bb` and included in the block mapping. Inserts new
1016/// instructions at the end of the block and leaves `builder` in a state
1017/// suitable for further insertion into the end of the block.
1018LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
1019 bool ignoreArguments,
1020 llvm::IRBuilderBase &builder,
1021 bool recordInsertions) {
1022 builder.SetInsertPoint(lookupBlock(block: &bb));
1023 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
1024
1025 // Before traversing operations, make block arguments available through
1026 // value remapping and PHI nodes, but do not add incoming edges for the PHI
1027 // nodes just yet: those values may be defined by this or following blocks.
1028 // This step is omitted if "ignoreArguments" is set. The arguments of the
1029 // first block have been already made available through the remapping of
1030 // LLVM function arguments.
1031 if (!ignoreArguments) {
1032 auto predecessors = bb.getPredecessors();
1033 unsigned numPredecessors =
1034 std::distance(first: predecessors.begin(), last: predecessors.end());
1035 for (auto arg : bb.getArguments()) {
1036 auto wrappedType = arg.getType();
1037 if (!isCompatibleType(type: wrappedType))
1038 return emitError(loc: bb.front().getLoc(),
1039 message: "block argument does not have an LLVM type");
1040 builder.SetCurrentDebugLocation(
1041 debugTranslation->translateLoc(loc: arg.getLoc(), scope: subprogram));
1042 llvm::Type *type = convertType(type: wrappedType);
1043 llvm::PHINode *phi = builder.CreatePHI(Ty: type, NumReservedValues: numPredecessors);
1044 mapValue(mlir: arg, llvm: phi);
1045 }
1046 }
1047
1048 // Traverse operations.
1049 for (auto &op : bb) {
1050 // Set the current debug location within the builder.
1051 builder.SetCurrentDebugLocation(
1052 debugTranslation->translateLoc(loc: op.getLoc(), scope: subprogram));
1053
1054 if (failed(Result: convertOperation(op, builder, recordInsertions)))
1055 return failure();
1056
1057 // Set the branch weight metadata on the translated instruction.
1058 if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
1059 setBranchWeightsMetadata(iface);
1060 }
1061
1062 return success();
1063}
1064
1065/// A helper method to get the single Block in an operation honoring LLVM's
1066/// module requirements.
1067static Block &getModuleBody(Operation *module) {
1068 return module->getRegion(index: 0).front();
1069}
1070
1071/// A helper method to decide if a constant must not be set as a global variable
1072/// initializer. For an external linkage variable, the variable with an
1073/// initializer is considered externally visible and defined in this module, the
1074/// variable without an initializer is externally available and is defined
1075/// elsewhere.
1076static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
1077 llvm::Constant *cst) {
1078 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
1079 linkage == llvm::GlobalVariable::ExternalWeakLinkage;
1080}
1081
1082/// Sets the runtime preemption specifier of `gv` to dso_local if
1083/// `dsoLocalRequested` is true, otherwise it is left unchanged.
1084static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
1085 llvm::GlobalValue *gv) {
1086 if (dsoLocalRequested)
1087 gv->setDSOLocal(true);
1088}
1089
1090LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
1091 // Mapping from compile unit to its respective set of global variables.
1092 DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars;
1093
1094 // First, create all global variables and global aliases in LLVM IR. A global
1095 // or alias body may refer to another global/alias or itself, so all the
1096 // mapping needs to happen prior to body conversion.
1097
1098 // Create all llvm::GlobalVariable
1099 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
1100 llvm::Type *type = convertType(op.getType());
1101 llvm::Constant *cst = nullptr;
1102 if (op.getValueOrNull()) {
1103 // String attributes are treated separately because they cannot appear as
1104 // in-function constants and are thus not supported by getLLVMConstant.
1105 if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
1106 cst = llvm::ConstantDataArray::getString(
1107 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
1108 type = cst->getType();
1109 } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
1110 *this))) {
1111 return failure();
1112 }
1113 }
1114
1115 auto linkage = convertLinkageToLLVM(op.getLinkage());
1116
1117 // LLVM IR requires constant with linkage other than external or weak
1118 // external to have initializers. If MLIR does not provide an initializer,
1119 // default to undef.
1120 bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
1121 if (!dropInitializer && !cst)
1122 cst = llvm::UndefValue::get(type);
1123 else if (dropInitializer && cst)
1124 cst = nullptr;
1125
1126 auto *var = new llvm::GlobalVariable(
1127 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
1128 /*InsertBefore=*/nullptr,
1129 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
1130 : llvm::GlobalValue::NotThreadLocal,
1131 op.getAddrSpace(), op.getExternallyInitialized());
1132
1133 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
1134 auto selectorOp = cast<ComdatSelectorOp>(
1135 SymbolTable::lookupNearestSymbolFrom(op, *comdat));
1136 var->setComdat(comdatMapping.lookup(selectorOp));
1137 }
1138
1139 if (op.getUnnamedAddr().has_value())
1140 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
1141
1142 if (op.getSection().has_value())
1143 var->setSection(*op.getSection());
1144
1145 addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
1146
1147 std::optional<uint64_t> alignment = op.getAlignment();
1148 if (alignment.has_value())
1149 var->setAlignment(llvm::MaybeAlign(alignment.value()));
1150
1151 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
1152
1153 globalsMapping.try_emplace(op, var);
1154
1155 // Add debug information if present.
1156 if (op.getDbgExprs()) {
1157 for (auto exprAttr :
1158 op.getDbgExprs()->getAsRange<DIGlobalVariableExpressionAttr>()) {
1159 llvm::DIGlobalVariableExpression *diGlobalExpr =
1160 debugTranslation->translateGlobalVariableExpression(exprAttr);
1161 llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable();
1162 var->addDebugInfo(diGlobalExpr);
1163
1164 // There is no `globals` field in DICompileUnitAttr which can be
1165 // directly assigned to DICompileUnit. We have to build the list by
1166 // looking at the dbgExpr of all the GlobalOps. The scope of the
1167 // variable is used to get the DICompileUnit in which to add it. But
1168 // there are cases where the scope of a global does not directly point
1169 // to the DICompileUnit and we have to do a bit more work to get to
1170 // it. Some of those cases are:
1171 //
1172 // 1. For the languages that support modules, the scope hierarchy can
1173 // be variable -> DIModule -> DICompileUnit
1174 //
1175 // 2. For the Fortran common block variable, the scope hierarchy can
1176 // be variable -> DICommonBlock -> DISubprogram -> DICompileUnit
1177 //
1178 // 3. For entities like static local variables in C or variable with
1179 // SAVE attribute in Fortran, the scope hierarchy can be
1180 // variable -> DISubprogram -> DICompileUnit
1181 llvm::DIScope *scope = diGlobalVar->getScope();
1182 if (auto *mod = dyn_cast_if_present<llvm::DIModule>(scope))
1183 scope = mod->getScope();
1184 else if (auto *cb = dyn_cast_if_present<llvm::DICommonBlock>(scope)) {
1185 if (auto *sp =
1186 dyn_cast_if_present<llvm::DISubprogram>(cb->getScope()))
1187 scope = sp->getUnit();
1188 } else if (auto *sp = dyn_cast_if_present<llvm::DISubprogram>(scope))
1189 scope = sp->getUnit();
1190
1191 // Get the compile unit (scope) of the the global variable.
1192 if (llvm::DICompileUnit *compileUnit =
1193 dyn_cast_if_present<llvm::DICompileUnit>(scope)) {
1194 // Update the compile unit with this incoming global variable
1195 // expression during the finalizing step later.
1196 allGVars[compileUnit].push_back(diGlobalExpr);
1197 }
1198 }
1199 }
1200 }
1201
1202 // Create all llvm::GlobalAlias
1203 for (auto op : getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) {
1204 llvm::Type *type = convertType(op.getType());
1205 llvm::Constant *cst = nullptr;
1206 llvm::GlobalValue::LinkageTypes linkage =
1207 convertLinkageToLLVM(op.getLinkage());
1208 llvm::Module &llvmMod = *llvmModule;
1209
1210 // Note address space and aliasee info isn't set just yet.
1211 llvm::GlobalAlias *var = llvm::GlobalAlias::create(
1212 type, op.getAddrSpace(), linkage, op.getSymName(), /*placeholder*/ cst,
1213 &llvmMod);
1214
1215 var->setThreadLocalMode(op.getThreadLocal_()
1216 ? llvm::GlobalAlias::GeneralDynamicTLSModel
1217 : llvm::GlobalAlias::NotThreadLocal);
1218
1219 // Note there is no need to setup the comdat because GlobalAlias calls into
1220 // the aliasee comdat information automatically.
1221
1222 if (op.getUnnamedAddr().has_value())
1223 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
1224
1225 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
1226
1227 aliasesMapping.try_emplace(op, var);
1228 }
1229
1230 // Convert global variable bodies.
1231 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
1232 if (Block *initializer = op.getInitializerBlock()) {
1233 llvm::IRBuilder<llvm::TargetFolder> builder(
1234 llvmModule->getContext(),
1235 llvm::TargetFolder(llvmModule->getDataLayout()));
1236
1237 [[maybe_unused]] int numConstantsHit = 0;
1238 [[maybe_unused]] int numConstantsErased = 0;
1239 DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
1240
1241 for (auto &op : initializer->without_terminator()) {
1242 if (failed(convertOperation(op, builder)))
1243 return emitError(op.getLoc(), "fail to convert global initializer");
1244 auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
1245 if (!cst)
1246 return emitError(op.getLoc(), "unemittable constant value");
1247
1248 // When emitting an LLVM constant, a new constant is created and the old
1249 // constant may become dangling and take space. We should remove the
1250 // dangling constants to avoid memory explosion especially for constant
1251 // arrays whose number of elements is large.
1252 // Because multiple operations may refer to the same constant, we need
1253 // to count the number of uses of each constant array and remove it only
1254 // when the count becomes zero.
1255 if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
1256 numConstantsHit++;
1257 Value result = op.getResult(0);
1258 int numUsers = std::distance(result.use_begin(), result.use_end());
1259 auto [iterator, inserted] =
1260 constantAggregateUseMap.try_emplace(agg, numUsers);
1261 if (!inserted) {
1262 // Key already exists, update the value
1263 iterator->second += numUsers;
1264 }
1265 }
1266 // Scan the operands of the operation to decrement the use count of
1267 // constants. Erase the constant if the use count becomes zero.
1268 for (Value v : op.getOperands()) {
1269 auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
1270 if (!cst)
1271 continue;
1272 auto iter = constantAggregateUseMap.find(cst);
1273 assert(iter != constantAggregateUseMap.end() && "constant not found");
1274 iter->second--;
1275 if (iter->second == 0) {
1276 // NOTE: cannot call removeDeadConstantUsers() here because it
1277 // may remove the constant which has uses not be converted yet.
1278 if (cst->user_empty()) {
1279 cst->destroyConstant();
1280 numConstantsErased++;
1281 }
1282 constantAggregateUseMap.erase(iter);
1283 }
1284 }
1285 }
1286
1287 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
1288 llvm::Constant *cst =
1289 cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
1290 auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
1291 if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
1292 global->setInitializer(cst);
1293
1294 // Try to remove the dangling constants again after all operations are
1295 // converted.
1296 for (auto it : constantAggregateUseMap) {
1297 auto cst = it.first;
1298 cst->removeDeadConstantUsers();
1299 if (cst->user_empty()) {
1300 cst->destroyConstant();
1301 numConstantsErased++;
1302 }
1303 }
1304
1305 LLVM_DEBUG(llvm::dbgs()
1306 << "Convert initializer for " << op.getName() << "\n";
1307 llvm::dbgs() << numConstantsHit << " new constants hit\n";
1308 llvm::dbgs()
1309 << numConstantsErased << " dangling constants erased\n";);
1310 }
1311 }
1312
1313 // Convert llvm.mlir.global_ctors and dtors.
1314 for (Operation &op : getModuleBody(module: mlirModule)) {
1315 auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
1316 auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
1317 if (!ctorOp && !dtorOp)
1318 continue;
1319
1320 // The empty / zero initialized version of llvm.global_(c|d)tors cannot be
1321 // handled by appendGlobalFn logic below, which just ignores empty (c|d)tor
1322 // lists. Make sure it gets emitted.
1323 if ((ctorOp && ctorOp.getCtors().empty()) ||
1324 (dtorOp && dtorOp.getDtors().empty())) {
1325 llvm::IRBuilder<llvm::TargetFolder> builder(
1326 llvmModule->getContext(),
1327 llvm::TargetFolder(llvmModule->getDataLayout()));
1328 llvm::Type *eltTy = llvm::StructType::get(
1329 elt1: builder.getInt32Ty(), elts: builder.getPtrTy(), elts: builder.getPtrTy());
1330 llvm::ArrayType *at = llvm::ArrayType::get(ElementType: eltTy, NumElements: 0);
1331 llvm::Constant *zeroInit = llvm::Constant::getNullValue(Ty: at);
1332 (void)new llvm::GlobalVariable(
1333 *llvmModule, zeroInit->getType(), false,
1334 llvm::GlobalValue::AppendingLinkage, zeroInit,
1335 ctorOp ? "llvm.global_ctors" : "llvm.global_dtors");
1336 } else {
1337 auto range = ctorOp
1338 ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
1339 : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
1340 auto appendGlobalFn =
1341 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
1342 for (const auto &[sym, prio] : range) {
1343 llvm::Function *f =
1344 lookupFunction(cast<FlatSymbolRefAttr>(sym).getValue());
1345 appendGlobalFn(*llvmModule, f, cast<IntegerAttr>(prio).getInt(),
1346 /*Data=*/nullptr);
1347 }
1348 }
1349 }
1350
1351 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
1352 if (failed(convertDialectAttributes(op, {})))
1353 return failure();
1354
1355 // Finally, update the compile units their respective sets of global variables
1356 // created earlier.
1357 for (const auto &[compileUnit, globals] : allGVars) {
1358 compileUnit->replaceGlobalVariables(
1359 N: llvm::MDTuple::get(Context&: getLLVMContext(), MDs: globals));
1360 }
1361
1362 // Convert global alias bodies.
1363 for (auto op : getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) {
1364 Block &initializer = op.getInitializerBlock();
1365 llvm::IRBuilder<llvm::TargetFolder> builder(
1366 llvmModule->getContext(),
1367 llvm::TargetFolder(llvmModule->getDataLayout()));
1368
1369 for (mlir::Operation &op : initializer.without_terminator()) {
1370 if (failed(convertOperation(op, builder)))
1371 return emitError(op.getLoc(), "fail to convert alias initializer");
1372 if (!isa<llvm::Constant>(lookupValue(op.getResult(0))))
1373 return emitError(op.getLoc(), "unemittable constant value");
1374 }
1375
1376 auto ret = cast<ReturnOp>(initializer.getTerminator());
1377 auto *cst = cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
1378 assert(aliasesMapping.count(op));
1379 auto *alias = cast<llvm::GlobalAlias>(aliasesMapping[op]);
1380 alias->setAliasee(cst);
1381 }
1382
1383 for (auto op : getModuleBody(mlirModule).getOps<LLVM::AliasOp>())
1384 if (failed(convertDialectAttributes(op, {})))
1385 return failure();
1386
1387 return success();
1388}
1389
1390/// Attempts to add an attribute identified by `key`, optionally with the given
1391/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
1392/// attribute has a kind known to LLVM IR, create the attribute of this kind,
1393/// otherwise keep it as a string attribute. Performs additional checks for
1394/// attributes known to have or not have a value in order to avoid assertions
1395/// inside LLVM upon construction.
1396static LogicalResult checkedAddLLVMFnAttribute(Location loc,
1397 llvm::Function *llvmFunc,
1398 StringRef key,
1399 StringRef value = StringRef()) {
1400 auto kind = llvm::Attribute::getAttrKindFromName(AttrName: key);
1401 if (kind == llvm::Attribute::None) {
1402 llvmFunc->addFnAttr(Kind: key, Val: value);
1403 return success();
1404 }
1405
1406 if (llvm::Attribute::isIntAttrKind(Kind: kind)) {
1407 if (value.empty())
1408 return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
1409
1410 int64_t result;
1411 if (!value.getAsInteger(/*Radix=*/0, Result&: result))
1412 llvmFunc->addFnAttr(
1413 Attr: llvm::Attribute::get(Context&: llvmFunc->getContext(), Kind: kind, Val: result));
1414 else
1415 llvmFunc->addFnAttr(Kind: key, Val: value);
1416 return success();
1417 }
1418
1419 if (!value.empty())
1420 return emitError(loc) << "LLVM attribute '" << key
1421 << "' does not expect a value, found '" << value
1422 << "'";
1423
1424 llvmFunc->addFnAttr(Kind: kind);
1425 return success();
1426}
1427
1428/// Return a representation of `value` as metadata.
1429static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
1430 const llvm::APInt &value) {
1431 llvm::Constant *constant = llvm::ConstantInt::get(Context&: context, V: value);
1432 return llvm::ConstantAsMetadata::get(C: constant);
1433}
1434
1435/// Return a representation of `value` as an MDNode.
1436static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
1437 const llvm::APInt &value) {
1438 return llvm::MDNode::get(Context&: context, MDs: convertIntegerToMetadata(context, value));
1439}
1440
1441/// Return an MDNode encoding `vec_type_hint` metadata.
1442static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
1443 llvm::Type *type,
1444 bool isSigned) {
1445 llvm::Metadata *typeMD =
1446 llvm::ConstantAsMetadata::get(C: llvm::UndefValue::get(T: type));
1447 llvm::Metadata *isSignedMD =
1448 convertIntegerToMetadata(context, value: llvm::APInt(32, isSigned ? 1 : 0));
1449 return llvm::MDNode::get(Context&: context, MDs: {typeMD, isSignedMD});
1450}
1451
1452/// Return an MDNode with a tuple given by the values in `values`.
1453static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
1454 ArrayRef<int32_t> values) {
1455 SmallVector<llvm::Metadata *> mdValues;
1456 llvm::transform(
1457 Range&: values, d_first: std::back_inserter(x&: mdValues), F: [&context](int32_t value) {
1458 return convertIntegerToMetadata(context, value: llvm::APInt(32, value));
1459 });
1460 return llvm::MDNode::get(Context&: context, MDs: mdValues);
1461}
1462
1463/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
1464/// Reports error to `loc` if any and returns immediately. Expects `attributes`
1465/// to be an array attribute containing either string attributes, treated as
1466/// value-less LLVM attributes, or array attributes containing two string
1467/// attributes, with the first string being the name of the corresponding LLVM
1468/// attribute and the second string beings its value. Note that even integer
1469/// attributes are expected to have their values expressed as strings.
1470static LogicalResult
1471forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
1472 llvm::Function *llvmFunc) {
1473 if (!attributes)
1474 return success();
1475
1476 for (Attribute attr : *attributes) {
1477 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1478 if (failed(
1479 checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
1480 return failure();
1481 continue;
1482 }
1483
1484 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
1485 if (!arrayAttr || arrayAttr.size() != 2)
1486 return emitError(loc)
1487 << "expected 'passthrough' to contain string or array attributes";
1488
1489 auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
1490 auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
1491 if (!keyAttr || !valueAttr)
1492 return emitError(loc)
1493 << "expected arrays within 'passthrough' to contain two strings";
1494
1495 if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
1496 valueAttr.getValue())))
1497 return failure();
1498 }
1499 return success();
1500}
1501
1502LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
1503 // Clear the block, branch value mappings, they are only relevant within one
1504 // function.
1505 blockMapping.clear();
1506 valueMapping.clear();
1507 branchMapping.clear();
1508 llvm::Function *llvmFunc = lookupFunction(name: func.getName());
1509
1510 // Add function arguments to the value remapping table.
1511 for (auto [mlirArg, llvmArg] :
1512 llvm::zip(func.getArguments(), llvmFunc->args()))
1513 mapValue(mlirArg, &llvmArg);
1514
1515 // Check the personality and set it.
1516 if (func.getPersonality()) {
1517 llvm::Type *ty = llvm::PointerType::getUnqual(C&: llvmFunc->getContext());
1518 if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
1519 func.getLoc(), *this))
1520 llvmFunc->setPersonalityFn(pfunc);
1521 }
1522
1523 if (std::optional<StringRef> section = func.getSection())
1524 llvmFunc->setSection(*section);
1525
1526 if (func.getArmStreaming())
1527 llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_enabled");
1528 else if (func.getArmLocallyStreaming())
1529 llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_body");
1530 else if (func.getArmStreamingCompatible())
1531 llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_compatible");
1532
1533 if (func.getArmNewZa())
1534 llvmFunc->addFnAttr(Kind: "aarch64_new_za");
1535 else if (func.getArmInZa())
1536 llvmFunc->addFnAttr(Kind: "aarch64_in_za");
1537 else if (func.getArmOutZa())
1538 llvmFunc->addFnAttr(Kind: "aarch64_out_za");
1539 else if (func.getArmInoutZa())
1540 llvmFunc->addFnAttr(Kind: "aarch64_inout_za");
1541 else if (func.getArmPreservesZa())
1542 llvmFunc->addFnAttr(Kind: "aarch64_preserves_za");
1543
1544 if (auto targetCpu = func.getTargetCpu())
1545 llvmFunc->addFnAttr("target-cpu", *targetCpu);
1546
1547 if (auto tuneCpu = func.getTuneCpu())
1548 llvmFunc->addFnAttr("tune-cpu", *tuneCpu);
1549
1550 if (auto reciprocalEstimates = func.getReciprocalEstimates())
1551 llvmFunc->addFnAttr("reciprocal-estimates", *reciprocalEstimates);
1552
1553 if (auto preferVectorWidth = func.getPreferVectorWidth())
1554 llvmFunc->addFnAttr("prefer-vector-width", *preferVectorWidth);
1555
1556 if (auto attr = func.getVscaleRange())
1557 llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
1558 Context&: getLLVMContext(), MinValue: attr->getMinRange().getInt(),
1559 MaxValue: attr->getMaxRange().getInt()));
1560
1561 if (auto unsafeFpMath = func.getUnsafeFpMath())
1562 llvmFunc->addFnAttr("unsafe-fp-math", llvm::toStringRef(*unsafeFpMath));
1563
1564 if (auto noInfsFpMath = func.getNoInfsFpMath())
1565 llvmFunc->addFnAttr("no-infs-fp-math", llvm::toStringRef(*noInfsFpMath));
1566
1567 if (auto noNansFpMath = func.getNoNansFpMath())
1568 llvmFunc->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath));
1569
1570 if (auto approxFuncFpMath = func.getApproxFuncFpMath())
1571 llvmFunc->addFnAttr("approx-func-fp-math",
1572 llvm::toStringRef(*approxFuncFpMath));
1573
1574 if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
1575 llvmFunc->addFnAttr("no-signed-zeros-fp-math",
1576 llvm::toStringRef(*noSignedZerosFpMath));
1577
1578 if (auto denormalFpMath = func.getDenormalFpMath())
1579 llvmFunc->addFnAttr("denormal-fp-math", *denormalFpMath);
1580
1581 if (auto denormalFpMathF32 = func.getDenormalFpMathF32())
1582 llvmFunc->addFnAttr("denormal-fp-math-f32", *denormalFpMathF32);
1583
1584 if (auto fpContract = func.getFpContract())
1585 llvmFunc->addFnAttr("fp-contract", *fpContract);
1586
1587 if (auto instrumentFunctionEntry = func.getInstrumentFunctionEntry())
1588 llvmFunc->addFnAttr("instrument-function-entry", *instrumentFunctionEntry);
1589
1590 if (auto instrumentFunctionExit = func.getInstrumentFunctionExit())
1591 llvmFunc->addFnAttr("instrument-function-exit", *instrumentFunctionExit);
1592
1593 // First, create all blocks so we can jump to them.
1594 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1595 for (auto &bb : func) {
1596 auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
1597 llvmBB->insertInto(llvmFunc);
1598 mapBlock(&bb, llvmBB);
1599 }
1600
1601 // Then, convert blocks one by one in topological order to ensure defs are
1602 // converted before uses.
1603 auto blocks = getBlocksSortedByDominance(func.getBody());
1604 for (Block *bb : blocks) {
1605 CapturingIRBuilder builder(llvmContext,
1606 llvm::TargetFolder(llvmModule->getDataLayout()));
1607 if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
1608 /*recordInsertions=*/true)))
1609 return failure();
1610 }
1611
1612 // After all blocks have been traversed and values mapped, connect the PHI
1613 // nodes to the results of preceding blocks.
1614 detail::connectPHINodes(region&: func.getBody(), state: *this);
1615
1616 // Finally, convert dialect attributes attached to the function.
1617 return convertDialectAttributes(op: func, instructions: {});
1618}
1619
1620LogicalResult ModuleTranslation::convertDialectAttributes(
1621 Operation *op, ArrayRef<llvm::Instruction *> instructions) {
1622 for (NamedAttribute attribute : op->getDialectAttrs())
1623 if (failed(Result: iface.amendOperation(op, instructions, attribute, moduleTranslation&: *this)))
1624 return failure();
1625 return success();
1626}
1627
1628/// Converts memory effect attributes from `func` and attaches them to
1629/// `llvmFunc`.
1630static void convertFunctionMemoryAttributes(LLVMFuncOp func,
1631 llvm::Function *llvmFunc) {
1632 if (!func.getMemoryEffects())
1633 return;
1634
1635 MemoryEffectsAttr memEffects = func.getMemoryEffectsAttr();
1636
1637 // Add memory effects incrementally.
1638 llvm::MemoryEffects newMemEffects =
1639 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
1640 convertModRefInfoToLLVM(memEffects.getArgMem()));
1641 newMemEffects |= llvm::MemoryEffects(
1642 llvm::MemoryEffects::Location::InaccessibleMem,
1643 convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
1644 newMemEffects |=
1645 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
1646 convertModRefInfoToLLVM(memEffects.getOther()));
1647 llvmFunc->setMemoryEffects(newMemEffects);
1648}
1649
1650/// Converts function attributes from `func` and attaches them to `llvmFunc`.
1651static void convertFunctionAttributes(LLVMFuncOp func,
1652 llvm::Function *llvmFunc) {
1653 if (func.getNoInlineAttr())
1654 llvmFunc->addFnAttr(llvm::Attribute::NoInline);
1655 if (func.getAlwaysInlineAttr())
1656 llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline);
1657 if (func.getOptimizeNoneAttr())
1658 llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone);
1659 if (func.getConvergentAttr())
1660 llvmFunc->addFnAttr(llvm::Attribute::Convergent);
1661 if (func.getNoUnwindAttr())
1662 llvmFunc->addFnAttr(llvm::Attribute::NoUnwind);
1663 if (func.getWillReturnAttr())
1664 llvmFunc->addFnAttr(llvm::Attribute::WillReturn);
1665 if (TargetFeaturesAttr targetFeatAttr = func.getTargetFeaturesAttr())
1666 llvmFunc->addFnAttr("target-features", targetFeatAttr.getFeaturesString());
1667 if (FramePointerKindAttr fpAttr = func.getFramePointerAttr())
1668 llvmFunc->addFnAttr("frame-pointer", stringifyFramePointerKind(
1669 fpAttr.getFramePointerKind()));
1670 if (UWTableKindAttr uwTableKindAttr = func.getUwtableKindAttr())
1671 llvmFunc->setUWTableKind(
1672 convertUWTableKindToLLVM(uwTableKindAttr.getUwtableKind()));
1673 convertFunctionMemoryAttributes(func, llvmFunc);
1674}
1675
1676/// Converts function attributes from `func` and attaches them to `llvmFunc`.
1677static void convertFunctionKernelAttributes(LLVMFuncOp func,
1678 llvm::Function *llvmFunc,
1679 ModuleTranslation &translation) {
1680 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1681
1682 if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) {
1683 Type type = vecTypeHint.getHint().getValue();
1684 llvm::Type *llvmType = translation.convertType(type);
1685 bool isSigned = vecTypeHint.getIsSigned();
1686 llvmFunc->setMetadata(
1687 func.getVecTypeHintAttrName(),
1688 convertVecTypeHintToMDNode(context&: llvmContext, type: llvmType, isSigned));
1689 }
1690
1691 if (std::optional<ArrayRef<int32_t>> workGroupSizeHint =
1692 func.getWorkGroupSizeHint()) {
1693 llvmFunc->setMetadata(
1694 func.getWorkGroupSizeHintAttrName(),
1695 convertIntegerArrayToMDNode(context&: llvmContext, values: *workGroupSizeHint));
1696 }
1697
1698 if (std::optional<ArrayRef<int32_t>> reqdWorkGroupSize =
1699 func.getReqdWorkGroupSize()) {
1700 llvmFunc->setMetadata(
1701 func.getReqdWorkGroupSizeAttrName(),
1702 convertIntegerArrayToMDNode(context&: llvmContext, values: *reqdWorkGroupSize));
1703 }
1704
1705 if (std::optional<uint32_t> intelReqdSubGroupSize =
1706 func.getIntelReqdSubGroupSize()) {
1707 llvmFunc->setMetadata(
1708 func.getIntelReqdSubGroupSizeAttrName(),
1709 convertIntegerToMDNode(context&: llvmContext,
1710 value: llvm::APInt(32, *intelReqdSubGroupSize)));
1711 }
1712}
1713
1714static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder,
1715 llvm::Attribute::AttrKind llvmKind,
1716 NamedAttribute namedAttr,
1717 ModuleTranslation &moduleTranslation,
1718 Location loc) {
1719 return llvm::TypeSwitch<Attribute, LogicalResult>(namedAttr.getValue())
1720 .Case<TypeAttr>([&](auto typeAttr) {
1721 attrBuilder.addTypeAttr(
1722 llvmKind, moduleTranslation.convertType(typeAttr.getValue()));
1723 return success();
1724 })
1725 .Case<IntegerAttr>([&](auto intAttr) {
1726 attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1727 return success();
1728 })
1729 .Case<UnitAttr>([&](auto) {
1730 attrBuilder.addAttribute(llvmKind);
1731 return success();
1732 })
1733 .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) {
1734 attrBuilder.addConstantRangeAttr(
1735 llvmKind,
1736 llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
1737 return success();
1738 })
1739 .Default([loc](auto) {
1740 return emitError(loc, "unsupported parameter attribute type");
1741 });
1742}
1743
1744FailureOr<llvm::AttrBuilder>
1745ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
1746 DictionaryAttr paramAttrs) {
1747 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1748 auto attrNameToKindMapping = getAttrNameToKindMapping();
1749 Location loc = func.getLoc();
1750
1751 for (auto namedAttr : paramAttrs) {
1752 auto it = attrNameToKindMapping.find(namedAttr.getName());
1753 if (it != attrNameToKindMapping.end()) {
1754 llvm::Attribute::AttrKind llvmKind = it->second;
1755 if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this,
1756 loc)))
1757 return failure();
1758 } else if (namedAttr.getNameDialect()) {
1759 if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
1760 return failure();
1761 }
1762 }
1763
1764 return attrBuilder;
1765}
1766
1767FailureOr<llvm::AttrBuilder>
1768ModuleTranslation::convertParameterAttrs(Location loc,
1769 DictionaryAttr paramAttrs) {
1770 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1771 auto attrNameToKindMapping = getAttrNameToKindMapping();
1772
1773 for (auto namedAttr : paramAttrs) {
1774 auto it = attrNameToKindMapping.find(namedAttr.getName());
1775 if (it != attrNameToKindMapping.end()) {
1776 llvm::Attribute::AttrKind llvmKind = it->second;
1777 if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this,
1778 loc)))
1779 return failure();
1780 }
1781 }
1782
1783 return attrBuilder;
1784}
1785
1786LogicalResult ModuleTranslation::convertFunctionSignatures() {
1787 // Declare all functions first because there may be function calls that form a
1788 // call graph with cycles, or global initializers that reference functions.
1789 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1790 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1791 function.getName(),
1792 cast<llvm::FunctionType>(convertType(function.getFunctionType())));
1793 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1794 llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
1795 llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv()));
1796 mapFunction(function.getName(), llvmFunc);
1797 addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
1798
1799 // Convert function attributes.
1800 convertFunctionAttributes(function, llvmFunc);
1801
1802 // Convert function kernel attributes to metadata.
1803 convertFunctionKernelAttributes(function, llvmFunc, *this);
1804
1805 // Convert function_entry_count attribute to metadata.
1806 if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
1807 llvmFunc->setEntryCount(entryCount.value());
1808
1809 // Convert result attributes.
1810 if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
1811 DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1812 FailureOr<llvm::AttrBuilder> attrBuilder =
1813 convertParameterAttrs(function, -1, resultAttrs);
1814 if (failed(attrBuilder))
1815 return failure();
1816 llvmFunc->addRetAttrs(*attrBuilder);
1817 }
1818
1819 // Convert argument attributes.
1820 for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
1821 if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1822 FailureOr<llvm::AttrBuilder> attrBuilder =
1823 convertParameterAttrs(function, argIdx, argAttrs);
1824 if (failed(attrBuilder))
1825 return failure();
1826 llvmArg.addAttrs(*attrBuilder);
1827 }
1828 }
1829
1830 // Forward the pass-through attributes to LLVM.
1831 if (failed(forwardPassthroughAttributes(
1832 function.getLoc(), function.getPassthrough(), llvmFunc)))
1833 return failure();
1834
1835 // Convert visibility attribute.
1836 llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
1837
1838 // Convert the comdat attribute.
1839 if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
1840 auto selectorOp = cast<ComdatSelectorOp>(
1841 SymbolTable::lookupNearestSymbolFrom(function, *comdat));
1842 llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1843 }
1844
1845 if (auto gc = function.getGarbageCollector())
1846 llvmFunc->setGC(gc->str());
1847
1848 if (auto unnamedAddr = function.getUnnamedAddr())
1849 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1850
1851 if (auto alignment = function.getAlignment())
1852 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1853
1854 // Translate the debug information for this function.
1855 debugTranslation->translate(function, *llvmFunc);
1856 }
1857
1858 return success();
1859}
1860
1861LogicalResult ModuleTranslation::convertFunctions() {
1862 // Convert functions.
1863 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1864 // Do not convert external functions, but do process dialect attributes
1865 // attached to them.
1866 if (function.isExternal()) {
1867 if (failed(convertDialectAttributes(function, {})))
1868 return failure();
1869 continue;
1870 }
1871
1872 if (failed(convertOneFunction(function)))
1873 return failure();
1874 }
1875
1876 return success();
1877}
1878
1879LogicalResult ModuleTranslation::convertComdats() {
1880 for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
1881 for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1882 llvm::Module *module = getLLVMModule();
1883 if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1884 return emitError(selectorOp.getLoc())
1885 << "comdat selection symbols must be unique even in different "
1886 "comdat regions";
1887 llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1888 comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1889 comdatMapping.try_emplace(selectorOp, comdat);
1890 }
1891 }
1892 return success();
1893}
1894
1895LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() {
1896 for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) {
1897 BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
1898 llvm::BasicBlock *llvmBlock = lookupBlockAddress(blockAddressAttr);
1899 assert(llvmBlock && "expected LLVM blocks to be already translated");
1900
1901 // Update mapping with new block address constant.
1902 auto *llvmBlockAddr = llvm::BlockAddress::get(
1903 lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock);
1904 llvmCst->replaceAllUsesWith(llvmBlockAddr);
1905 assert(llvmCst->use_empty() && "expected all uses to be replaced");
1906 cast<llvm::GlobalVariable>(llvmCst)->eraseFromParent();
1907 }
1908 unresolvedBlockAddressMapping.clear();
1909 return success();
1910}
1911
1912void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
1913 llvm::Instruction *inst) {
1914 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1915 inst->setMetadata(KindID: llvm::LLVMContext::MD_access_group, Node: node);
1916}
1917
1918llvm::MDNode *
1919ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
1920 auto [scopeIt, scopeInserted] =
1921 aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr);
1922 if (!scopeInserted)
1923 return scopeIt->second;
1924 llvm::LLVMContext &ctx = llvmModule->getContext();
1925 auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: std::nullopt);
1926 // Convert the domain metadata node if necessary.
1927 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1928 aliasScopeAttr.getDomain(), nullptr);
1929 if (insertedDomain) {
1930 llvm::SmallVector<llvm::Metadata *, 2> operands;
1931 // Placeholder for potential self-reference.
1932 operands.push_back(Elt: dummy.get());
1933 if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1934 operands.push_back(Elt: llvm::MDString::get(ctx, description));
1935 domainIt->second = llvm::MDNode::get(ctx, operands);
1936 // Self-reference for uniqueness.
1937 llvm::Metadata *replacement;
1938 if (auto stringAttr =
1939 dyn_cast<StringAttr>(aliasScopeAttr.getDomain().getId()))
1940 replacement = llvm::MDString::get(ctx, stringAttr.getValue());
1941 else
1942 replacement = domainIt->second;
1943 domainIt->second->replaceOperandWith(0, replacement);
1944 }
1945 // Convert the scope metadata node.
1946 assert(domainIt->second && "Scope's domain should already be valid");
1947 llvm::SmallVector<llvm::Metadata *, 3> operands;
1948 // Placeholder for potential self-reference.
1949 operands.push_back(Elt: dummy.get());
1950 operands.push_back(Elt: domainIt->second);
1951 if (StringAttr description = aliasScopeAttr.getDescription())
1952 operands.push_back(Elt: llvm::MDString::get(ctx, description));
1953 scopeIt->second = llvm::MDNode::get(ctx, operands);
1954 // Self-reference for uniqueness.
1955 llvm::Metadata *replacement;
1956 if (auto stringAttr = dyn_cast<StringAttr>(aliasScopeAttr.getId()))
1957 replacement = llvm::MDString::get(ctx, stringAttr.getValue());
1958 else
1959 replacement = scopeIt->second;
1960 scopeIt->second->replaceOperandWith(0, replacement);
1961 return scopeIt->second;
1962}
1963
1964llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes(
1965 ArrayRef<AliasScopeAttr> aliasScopeAttrs) {
1966 SmallVector<llvm::Metadata *> nodes;
1967 nodes.reserve(N: aliasScopeAttrs.size());
1968 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1969 nodes.push_back(getOrCreateAliasScope(aliasScopeAttr));
1970 return llvm::MDNode::get(Context&: getLLVMContext(), MDs: nodes);
1971}
1972
1973void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
1974 llvm::Instruction *inst) {
1975 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) {
1976 if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1977 return;
1978 llvm::MDNode *node = getOrCreateAliasScopes(
1979 aliasScopeAttrs: llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1980 inst->setMetadata(KindID: kind, Node: node);
1981 };
1982
1983 populateScopeMetadata(op.getAliasScopesOrNull(),
1984 llvm::LLVMContext::MD_alias_scope);
1985 populateScopeMetadata(op.getNoAliasScopesOrNull(),
1986 llvm::LLVMContext::MD_noalias);
1987}
1988
1989llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1990 return tbaaMetadataMapping.lookup(Val: tbaaAttr);
1991}
1992
1993void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
1994 llvm::Instruction *inst) {
1995 ArrayAttr tagRefs = op.getTBAATagsOrNull();
1996 if (!tagRefs || tagRefs.empty())
1997 return;
1998
1999 // LLVM IR currently does not support attaching more than one TBAA access tag
2000 // to a memory accessing instruction. It may be useful to support this in
2001 // future, but for the time being just ignore the metadata if MLIR operation
2002 // has multiple access tags.
2003 if (tagRefs.size() > 1) {
2004 op.emitWarning() << "TBAA access tags were not translated, because LLVM "
2005 "IR only supports a single tag per instruction";
2006 return;
2007 }
2008
2009 llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
2010 inst->setMetadata(KindID: llvm::LLVMContext::MD_tbaa, Node: node);
2011}
2012
2013void ModuleTranslation::setDereferenceableMetadata(
2014 DereferenceableOpInterface op, llvm::Instruction *inst) {
2015 DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
2016 if (!derefAttr)
2017 return;
2018
2019 llvm::MDNode *derefSizeNode = llvm::MDNode::get(
2020 Context&: getLLVMContext(),
2021 MDs: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get(
2022 llvm::IntegerType::get(C&: getLLVMContext(), NumBits: 64), derefAttr.getBytes())));
2023 unsigned kindId = derefAttr.getMayBeNull()
2024 ? llvm::LLVMContext::MD_dereferenceable_or_null
2025 : llvm::LLVMContext::MD_dereferenceable;
2026 inst->setMetadata(KindID: kindId, Node: derefSizeNode);
2027}
2028
2029void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
2030 DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
2031 if (!weightsAttr)
2032 return;
2033
2034 llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
2035 assert(inst && "expected the operation to have a mapping to an instruction");
2036 SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
2037 inst->setMetadata(
2038 KindID: llvm::LLVMContext::MD_prof,
2039 Node: llvm::MDBuilder(getLLVMContext()).createBranchWeights(Weights: weights));
2040}
2041
2042LogicalResult ModuleTranslation::createTBAAMetadata() {
2043 llvm::LLVMContext &ctx = llvmModule->getContext();
2044 llvm::IntegerType *offsetTy = llvm::IntegerType::get(C&: ctx, NumBits: 64);
2045
2046 // Walk the entire module and create all metadata nodes for the TBAA
2047 // attributes. The code below relies on two invariants of the
2048 // `AttrTypeWalker`:
2049 // 1. Attributes are visited in post-order: Since the attributes create a DAG,
2050 // this ensures that any lookups into `tbaaMetadataMapping` for child
2051 // attributes succeed.
2052 // 2. Attributes are only ever visited once: This way we don't leak any
2053 // LLVM metadata instances.
2054 AttrTypeWalker walker;
2055 walker.addWalk(callback: [&](TBAARootAttr root) {
2056 tbaaMetadataMapping.insert(
2057 {root, llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(ctx, root.getId()))});
2058 });
2059
2060 walker.addWalk(callback: [&](TBAATypeDescriptorAttr descriptor) {
2061 SmallVector<llvm::Metadata *> operands;
2062 operands.push_back(Elt: llvm::MDString::get(ctx, descriptor.getId()));
2063 for (TBAAMemberAttr member : descriptor.getMembers()) {
2064 operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
2065 operands.push_back(llvm::ConstantAsMetadata::get(
2066 llvm::ConstantInt::get(offsetTy, member.getOffset())));
2067 }
2068
2069 tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(Context&: ctx, MDs: operands)});
2070 });
2071
2072 walker.addWalk(callback: [&](TBAATagAttr tag) {
2073 SmallVector<llvm::Metadata *> operands;
2074
2075 operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getBaseType()));
2076 operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getAccessType()));
2077
2078 operands.push_back(Elt: llvm::ConstantAsMetadata::get(
2079 C: llvm::ConstantInt::get(offsetTy, tag.getOffset())));
2080 if (tag.getConstant())
2081 operands.push_back(
2082 Elt: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get(Ty: offsetTy, V: 1)));
2083
2084 tbaaMetadataMapping.insert({tag, llvm::MDNode::get(Context&: ctx, MDs: operands)});
2085 });
2086
2087 mlirModule->walk(callback: [&](AliasAnalysisOpInterface analysisOpInterface) {
2088 if (auto attr = analysisOpInterface.getTBAATagsOrNull())
2089 walker.walk(attr);
2090 });
2091
2092 return success();
2093}
2094
2095LogicalResult ModuleTranslation::createIdentMetadata() {
2096 if (auto attr = mlirModule->getAttrOfType<StringAttr>(
2097 LLVMDialect::getIdentAttrName())) {
2098 StringRef ident = attr;
2099 llvm::LLVMContext &ctx = llvmModule->getContext();
2100 llvm::NamedMDNode *namedMd =
2101 llvmModule->getOrInsertNamedMetadata(LLVMDialect::getIdentAttrName());
2102 llvm::MDNode *md = llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: ident));
2103 namedMd->addOperand(M: md);
2104 }
2105
2106 return success();
2107}
2108
2109LogicalResult ModuleTranslation::createCommandlineMetadata() {
2110 if (auto attr = mlirModule->getAttrOfType<StringAttr>(
2111 LLVMDialect::getCommandlineAttrName())) {
2112 StringRef cmdLine = attr;
2113 llvm::LLVMContext &ctx = llvmModule->getContext();
2114 llvm::NamedMDNode *nmd = llvmModule->getOrInsertNamedMetadata(
2115 LLVMDialect::getCommandlineAttrName());
2116 llvm::MDNode *md =
2117 llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: cmdLine));
2118 nmd->addOperand(M: md);
2119 }
2120
2121 return success();
2122}
2123
2124LogicalResult ModuleTranslation::createDependentLibrariesMetadata() {
2125 if (auto dependentLibrariesAttr = mlirModule->getDiscardableAttr(
2126 LLVM::LLVMDialect::getDependentLibrariesAttrName())) {
2127 auto *nmd =
2128 llvmModule->getOrInsertNamedMetadata(Name: "llvm.dependent-libraries");
2129 llvm::LLVMContext &ctx = llvmModule->getContext();
2130 for (auto libAttr :
2131 cast<ArrayAttr>(dependentLibrariesAttr).getAsRange<StringAttr>()) {
2132 auto *md =
2133 llvm::MDNode::get(ctx, llvm::MDString::get(ctx, libAttr.getValue()));
2134 nmd->addOperand(md);
2135 }
2136 }
2137 return success();
2138}
2139
2140void ModuleTranslation::setLoopMetadata(Operation *op,
2141 llvm::Instruction *inst) {
2142 LoopAnnotationAttr attr =
2143 TypeSwitch<Operation *, LoopAnnotationAttr>(op)
2144 .Case<LLVM::BrOp, LLVM::CondBrOp>(
2145 [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
2146 if (!attr)
2147 return;
2148 llvm::MDNode *loopMD =
2149 loopAnnotationTranslation->translateLoopAnnotation(attr, op);
2150 inst->setMetadata(KindID: llvm::LLVMContext::MD_loop, Node: loopMD);
2151}
2152
2153void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
2154 auto iface = cast<DisjointFlagInterface>(op);
2155 // We do a dyn_cast here in case the value got folded into a constant.
2156 if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(Val: value))
2157 disjointInst->setIsDisjoint(iface.getIsDisjoint());
2158}
2159
2160llvm::Type *ModuleTranslation::convertType(Type type) {
2161 return typeTranslator.translateType(type);
2162}
2163
2164/// A helper to look up remapped operands in the value remapping table.
2165SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) {
2166 SmallVector<llvm::Value *> remapped;
2167 remapped.reserve(N: values.size());
2168 for (Value v : values)
2169 remapped.push_back(Elt: lookupValue(value: v));
2170 return remapped;
2171}
2172
2173llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
2174 if (!ompBuilder) {
2175 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(args&: *llvmModule);
2176 ompBuilder->initialize();
2177
2178 // Flags represented as top-level OpenMP dialect attributes are set in
2179 // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
2180 // the default configuration.
2181 ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
2182 /* IsTargetDevice = */ false, /* IsGPU = */ false,
2183 /* OpenMPOffloadMandatory = */ false,
2184 /* HasRequiresReverseOffload = */ false,
2185 /* HasRequiresUnifiedAddress = */ false,
2186 /* HasRequiresUnifiedSharedMemory = */ false,
2187 /* HasRequiresDynamicAllocators = */ false));
2188 }
2189 return ompBuilder.get();
2190}
2191
2192llvm::DILocation *ModuleTranslation::translateLoc(Location loc,
2193 llvm::DILocalScope *scope) {
2194 return debugTranslation->translateLoc(loc, scope);
2195}
2196
2197llvm::DIExpression *
2198ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) {
2199 return debugTranslation->translateExpression(attr);
2200}
2201
2202llvm::DIGlobalVariableExpression *
2203ModuleTranslation::translateGlobalVariableExpression(
2204 LLVM::DIGlobalVariableExpressionAttr attr) {
2205 return debugTranslation->translateGlobalVariableExpression(attr);
2206}
2207
2208llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
2209 return debugTranslation->translate(attr);
2210}
2211
2212llvm::RoundingMode
2213ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
2214 return convertRoundingModeToLLVM(rounding);
2215}
2216
2217llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
2218 LLVM::FPExceptionBehavior exceptionBehavior) {
2219 return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
2220}
2221
2222llvm::NamedMDNode *
2223ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
2224 return llvmModule->getOrInsertNamedMetadata(Name: name);
2225}
2226
2227void ModuleTranslation::StackFrame::anchor() {}
2228
2229static std::unique_ptr<llvm::Module>
2230prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
2231 StringRef name) {
2232 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
2233 auto llvmModule = std::make_unique<llvm::Module>(args&: name, args&: llvmContext);
2234 // ModuleTranslation can currently only construct modules in the old debug
2235 // info format, so set the flag accordingly.
2236 llvmModule->setNewDbgInfoFormatFlag(false);
2237 if (auto dataLayoutAttr =
2238 m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
2239 llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
2240 } else {
2241 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
2242 if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
2243 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
2244 llvmDataLayout =
2245 translateDataLayout(spec, DataLayout(iface), m->getLoc());
2246 }
2247 } else if (auto mod = dyn_cast<ModuleOp>(m)) {
2248 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
2249 llvmDataLayout =
2250 translateDataLayout(spec, DataLayout(mod), m->getLoc());
2251 }
2252 }
2253 if (failed(Result: llvmDataLayout))
2254 return nullptr;
2255 llvmModule->setDataLayout(*llvmDataLayout);
2256 }
2257 if (auto targetTripleAttr =
2258 m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
2259 llvmModule->setTargetTriple(
2260 llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue()));
2261
2262 return llvmModule;
2263}
2264
2265std::unique_ptr<llvm::Module>
2266mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
2267 StringRef name, bool disableVerification) {
2268 if (!satisfiesLLVMModule(op: module)) {
2269 module->emitOpError(message: "can not be translated to an LLVMIR module");
2270 return nullptr;
2271 }
2272
2273 std::unique_ptr<llvm::Module> llvmModule =
2274 prepareLLVMModule(m: module, llvmContext, name);
2275 if (!llvmModule)
2276 return nullptr;
2277
2278 LLVM::ensureDistinctSuccessors(op: module);
2279 LLVM::legalizeDIExpressionsRecursively(op: module);
2280
2281 ModuleTranslation translator(module, std::move(llvmModule));
2282 llvm::IRBuilder<llvm::TargetFolder> llvmBuilder(
2283 llvmContext,
2284 llvm::TargetFolder(translator.getLLVMModule()->getDataLayout()));
2285
2286 // Convert module before functions and operations inside, so dialect
2287 // attributes can be used to change dialect-specific global configurations via
2288 // `amendOperation()`. These configurations can then influence the translation
2289 // of operations afterwards.
2290 if (failed(Result: translator.convertOperation(op&: *module, builder&: llvmBuilder)))
2291 return nullptr;
2292
2293 if (failed(Result: translator.convertComdats()))
2294 return nullptr;
2295 if (failed(Result: translator.convertFunctionSignatures()))
2296 return nullptr;
2297 if (failed(Result: translator.convertGlobalsAndAliases()))
2298 return nullptr;
2299 if (failed(Result: translator.createTBAAMetadata()))
2300 return nullptr;
2301 if (failed(Result: translator.createIdentMetadata()))
2302 return nullptr;
2303 if (failed(Result: translator.createCommandlineMetadata()))
2304 return nullptr;
2305 if (failed(Result: translator.createDependentLibrariesMetadata()))
2306 return nullptr;
2307
2308 // Convert other top-level operations if possible.
2309 for (Operation &o : getModuleBody(module).getOperations()) {
2310 if (!isa<LLVM::LLVMFuncOp, LLVM::AliasOp, LLVM::GlobalOp,
2311 LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
2312 !o.hasTrait<OpTrait::IsTerminator>() &&
2313 failed(translator.convertOperation(o, llvmBuilder))) {
2314 return nullptr;
2315 }
2316 }
2317
2318 // Operations in function bodies with symbolic references must be converted
2319 // after the top-level operations they refer to are declared, so we do it
2320 // last.
2321 if (failed(Result: translator.convertFunctions()))
2322 return nullptr;
2323
2324 // Now that all MLIR blocks are resolved into LLVM ones, patch block address
2325 // constants to point to the correct blocks.
2326 if (failed(Result: translator.convertUnresolvedBlockAddress()))
2327 return nullptr;
2328
2329 // Once we've finished constructing elements in the module, we should convert
2330 // it to use the debug info format desired by LLVM.
2331 // See https://llvm.org/docs/RemoveDIsDebugInfo.html
2332 translator.llvmModule->setIsNewDbgInfoFormat(true);
2333
2334 // Add the necessary debug info module flags, if they were not encoded in MLIR
2335 // beforehand.
2336 translator.debugTranslation->addModuleFlagsIfNotPresent();
2337
2338 if (!disableVerification &&
2339 llvm::verifyModule(M: *translator.llvmModule, OS: &llvm::errs()))
2340 return nullptr;
2341
2342 return std::move(translator.llvmModule);
2343}
2344

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Target/LLVMIR/ModuleTranslation.cpp