1//===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the translation between an MLIR LLVM dialect module and
10// the corresponding LLVMIR module. It only handles core LLVM IR operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Target/LLVMIR/ModuleTranslation.h"
15
16#include "AttrKindDetail.h"
17#include "DebugTranslation.h"
18#include "LoopAnnotationTranslation.h"
19#include "mlir/Analysis/TopologicalSortUtils.h"
20#include "mlir/Dialect/DLTI/DLTI.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
23#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
24#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
25#include "mlir/IR/AttrTypeSubElements.h"
26#include "mlir/IR/Attributes.h"
27#include "mlir/IR/BuiltinOps.h"
28#include "mlir/IR/BuiltinTypes.h"
29#include "mlir/IR/DialectResourceBlobManager.h"
30#include "mlir/Support/LLVM.h"
31#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
32#include "mlir/Target/LLVMIR/TypeToLLVM.h"
33
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Analysis/TargetFolder.h"
38#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/CFG.h"
41#include "llvm/IR/Constants.h"
42#include "llvm/IR/DerivedTypes.h"
43#include "llvm/IR/IRBuilder.h"
44#include "llvm/IR/InlineAsm.h"
45#include "llvm/IR/LLVMContext.h"
46#include "llvm/IR/MDBuilder.h"
47#include "llvm/IR/Module.h"
48#include "llvm/IR/Verifier.h"
49#include "llvm/Support/Debug.h"
50#include "llvm/Support/ErrorHandling.h"
51#include "llvm/Support/raw_ostream.h"
52#include "llvm/Transforms/Utils/BasicBlockUtils.h"
53#include "llvm/Transforms/Utils/Cloning.h"
54#include "llvm/Transforms/Utils/ModuleUtils.h"
55#include <numeric>
56#include <optional>
57
58#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
59
60using namespace mlir;
61using namespace mlir::LLVM;
62using namespace mlir::LLVM::detail;
63
64#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
65
66namespace {
67/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
68/// instructions that are created for future reference.
69///
70/// This is intended to be used with the `CollectionScope` RAII object:
71///
72/// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
73/// {
74/// InstructionCapturingInserter::CollectionScope scope(builder);
75/// // Call IRBuilder methods as usual.
76///
77/// // This will return a list of all instructions created by the builder,
78/// // in order of creation.
79/// builder.getInserter().getCapturedInstructions();
80/// }
81/// // This will return an empty list.
82/// builder.getInserter().getCapturedInstructions();
83///
84/// The capturing functionality is _disabled_ by default for performance
85/// consideration. It needs to be explicitly enabled, which is achieved by
86/// creating a `CollectionScope`.
87class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
88public:
89 /// Constructs the inserter.
90 InstructionCapturingInserter()
91 : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
92 if (LLVM_LIKELY(enabled))
93 capturedInstructions.push_back(Elt: instruction);
94 }) {}
95
96 /// Returns the list of LLVM IR instructions captured since the last cleanup.
97 ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
98 return capturedInstructions;
99 }
100
101 /// Clears the list of captured LLVM IR instructions.
102 void clearCapturedInstructions() { capturedInstructions.clear(); }
103
104 /// RAII object enabling the capture of created LLVM IR instructions.
105 class CollectionScope {
106 public:
107 /// Creates the scope for the given inserter.
108 CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
109
110 /// Ends the scope.
111 ~CollectionScope();
112
113 ArrayRef<llvm::Instruction *> getCapturedInstructions() {
114 if (!inserter)
115 return {};
116 return inserter->getCapturedInstructions();
117 }
118
119 private:
120 /// Back reference to the inserter.
121 InstructionCapturingInserter *inserter = nullptr;
122
123 /// List of instructions in the inserter prior to this scope.
124 SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
125
126 /// Whether the inserter was enabled prior to this scope.
127 bool wasEnabled;
128 };
129
130 /// Enable or disable the capturing mechanism.
131 void setEnabled(bool enabled = true) { this->enabled = enabled; }
132
133private:
134 /// List of captured instructions.
135 SmallVector<llvm::Instruction *> capturedInstructions;
136
137 /// Whether the collection is enabled.
138 bool enabled = false;
139};
140
141using CapturingIRBuilder =
142 llvm::IRBuilder<llvm::TargetFolder, InstructionCapturingInserter>;
143} // namespace
144
145InstructionCapturingInserter::CollectionScope::CollectionScope(
146 llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
147
148 if (!isBuilderCapturing)
149 return;
150
151 auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
152 inserter = &capturingIRBuilder.getInserter();
153 wasEnabled = inserter->enabled;
154 if (wasEnabled)
155 previouslyCollectedInstructions.swap(RHS&: inserter->capturedInstructions);
156 inserter->setEnabled(true);
157}
158
159InstructionCapturingInserter::CollectionScope::~CollectionScope() {
160 if (!inserter)
161 return;
162
163 previouslyCollectedInstructions.swap(RHS&: inserter->capturedInstructions);
164 // If collection was enabled (likely in another, surrounding scope), keep
165 // the instructions collected in this scope.
166 if (wasEnabled) {
167 llvm::append_range(C&: inserter->capturedInstructions,
168 R&: previouslyCollectedInstructions);
169 }
170 inserter->setEnabled(wasEnabled);
171}
172
173/// Translates the given data layout spec attribute to the LLVM IR data layout.
174/// Only integer, float, pointer and endianness entries are currently supported.
175static FailureOr<llvm::DataLayout>
176translateDataLayout(DataLayoutSpecInterface attribute,
177 const DataLayout &dataLayout,
178 std::optional<Location> loc = std::nullopt) {
179 if (!loc)
180 loc = UnknownLoc::get(context: attribute.getContext());
181
182 // Translate the endianness attribute.
183 std::string llvmDataLayout;
184 llvm::raw_string_ostream layoutStream(llvmDataLayout);
185 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
186 auto key = llvm::dyn_cast_if_present<StringAttr>(Val: entry.getKey());
187 if (!key)
188 continue;
189 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
190 auto value = cast<StringAttr>(Val: entry.getValue());
191 bool isLittleEndian =
192 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
193 layoutStream << "-" << (isLittleEndian ? "e" : "E");
194 continue;
195 }
196 if (key.getValue() == DLTIDialect::kDataLayoutManglingModeKey) {
197 auto value = cast<StringAttr>(Val: entry.getValue());
198 layoutStream << "-m:" << value.getValue();
199 continue;
200 }
201 if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) {
202 auto value = cast<IntegerAttr>(Val: entry.getValue());
203 uint64_t space = value.getValue().getZExtValue();
204 // Skip the default address space.
205 if (space == 0)
206 continue;
207 layoutStream << "-P" << space;
208 continue;
209 }
210 if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) {
211 auto value = cast<IntegerAttr>(Val: entry.getValue());
212 uint64_t space = value.getValue().getZExtValue();
213 // Skip the default address space.
214 if (space == 0)
215 continue;
216 layoutStream << "-G" << space;
217 continue;
218 }
219 if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
220 auto value = cast<IntegerAttr>(Val: entry.getValue());
221 uint64_t space = value.getValue().getZExtValue();
222 // Skip the default address space.
223 if (space == 0)
224 continue;
225 layoutStream << "-A" << space;
226 continue;
227 }
228 if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
229 auto value = cast<IntegerAttr>(Val: entry.getValue());
230 uint64_t alignment = value.getValue().getZExtValue();
231 // Skip the default stack alignment.
232 if (alignment == 0)
233 continue;
234 layoutStream << "-S" << alignment;
235 continue;
236 }
237 if (key.getValue() == DLTIDialect::kDataLayoutFunctionPointerAlignmentKey) {
238 auto value = cast<FunctionPointerAlignmentAttr>(Val: entry.getValue());
239 uint64_t alignment = value.getAlignment();
240 // Skip the default function pointer alignment.
241 if (alignment == 0)
242 continue;
243 layoutStream << "-F" << (value.getFunctionDependent() ? "n" : "i")
244 << alignment;
245 continue;
246 }
247 if (key.getValue() == DLTIDialect::kDataLayoutLegalIntWidthsKey) {
248 layoutStream << "-n";
249 llvm::interleave(
250 c: cast<DenseI32ArrayAttr>(Val: entry.getValue()).asArrayRef(), os&: layoutStream,
251 each_fn: [&](int32_t val) { layoutStream << val; }, separator: ":");
252 continue;
253 }
254 emitError(loc: *loc) << "unsupported data layout key " << key;
255 return failure();
256 }
257
258 // Go through the list of entries to check which types are explicitly
259 // specified in entries. Where possible, data layout queries are used instead
260 // of directly inspecting the entries.
261 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
262 auto type = llvm::dyn_cast_if_present<Type>(Val: entry.getKey());
263 if (!type)
264 continue;
265 // Data layout for the index type is irrelevant at this point.
266 if (isa<IndexType>(Val: type))
267 continue;
268 layoutStream << "-";
269 LogicalResult result =
270 llvm::TypeSwitch<Type, LogicalResult>(type)
271 .Case<IntegerType, Float16Type, Float32Type, Float64Type,
272 Float80Type, Float128Type>(caseFn: [&](Type type) -> LogicalResult {
273 if (auto intType = dyn_cast<IntegerType>(Val&: type)) {
274 if (intType.getSignedness() != IntegerType::Signless)
275 return emitError(loc: *loc)
276 << "unsupported data layout for non-signless integer "
277 << intType;
278 layoutStream << "i";
279 } else {
280 layoutStream << "f";
281 }
282 uint64_t size = dataLayout.getTypeSizeInBits(t: type);
283 uint64_t abi = dataLayout.getTypeABIAlignment(t: type) * 8u;
284 uint64_t preferred =
285 dataLayout.getTypePreferredAlignment(t: type) * 8u;
286 layoutStream << size << ":" << abi;
287 if (abi != preferred)
288 layoutStream << ":" << preferred;
289 return success();
290 })
291 .Case(caseFn: [&](LLVMPointerType type) {
292 layoutStream << "p" << type.getAddressSpace() << ":";
293 uint64_t size = dataLayout.getTypeSizeInBits(t: type);
294 uint64_t abi = dataLayout.getTypeABIAlignment(t: type) * 8u;
295 uint64_t preferred =
296 dataLayout.getTypePreferredAlignment(t: type) * 8u;
297 uint64_t index = *dataLayout.getTypeIndexBitwidth(t: type);
298 layoutStream << size << ":" << abi << ":" << preferred << ":"
299 << index;
300 return success();
301 })
302 .Default(defaultFn: [loc](Type type) {
303 return emitError(loc: *loc)
304 << "unsupported type in data layout: " << type;
305 });
306 if (failed(Result: result))
307 return failure();
308 }
309 StringRef layoutSpec(llvmDataLayout);
310 layoutSpec.consume_front(Prefix: "-");
311
312 return llvm::DataLayout(layoutSpec);
313}
314
315/// Builds a constant of a sequential LLVM type `type`, potentially containing
316/// other sequential types recursively, from the individual constant values
317/// provided in `constants`. `shape` contains the number of elements in nested
318/// sequential types. Reports errors at `loc` and returns nullptr on error.
319static llvm::Constant *
320buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
321 ArrayRef<int64_t> shape, llvm::Type *type,
322 Location loc) {
323 if (shape.empty()) {
324 llvm::Constant *result = constants.front();
325 constants = constants.drop_front();
326 return result;
327 }
328
329 llvm::Type *elementType;
330 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) {
331 elementType = arrayTy->getElementType();
332 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) {
333 elementType = vectorTy->getElementType();
334 } else {
335 emitError(loc) << "expected sequential LLVM types wrapping a scalar";
336 return nullptr;
337 }
338
339 SmallVector<llvm::Constant *, 8> nested;
340 nested.reserve(N: shape.front());
341 for (int64_t i = 0; i < shape.front(); ++i) {
342 nested.push_back(Elt: buildSequentialConstant(constants, shape: shape.drop_front(),
343 type: elementType, loc));
344 if (!nested.back())
345 return nullptr;
346 }
347
348 if (shape.size() == 1 && type->isVectorTy())
349 return llvm::ConstantVector::get(V: nested);
350 return llvm::ConstantArray::get(
351 T: llvm::ArrayType::get(ElementType: elementType, NumElements: shape.front()), V: nested);
352}
353
354/// Returns the first non-sequential type nested in sequential types.
355static llvm::Type *getInnermostElementType(llvm::Type *type) {
356 do {
357 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) {
358 type = arrayTy->getElementType();
359 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) {
360 type = vectorTy->getElementType();
361 } else {
362 return type;
363 }
364 } while (true);
365}
366
367/// Convert a dense elements attribute to an LLVM IR constant using its raw data
368/// storage if possible. This supports elements attributes of tensor or vector
369/// type and avoids constructing separate objects for individual values of the
370/// innermost dimension. Constants for other dimensions are still constructed
371/// recursively. Returns null if constructing from raw data is not supported for
372/// this type, e.g., element type is not a power-of-two-sized primitive. Reports
373/// other errors at `loc`.
374static llvm::Constant *
375convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
376 llvm::Type *llvmType,
377 const ModuleTranslation &moduleTranslation) {
378 if (!denseElementsAttr)
379 return nullptr;
380
381 llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType);
382 if (!llvm::ConstantDataSequential::isElementTypeCompatible(Ty: innermostLLVMType))
383 return nullptr;
384
385 ShapedType type = denseElementsAttr.getType();
386 if (type.getNumElements() == 0)
387 return nullptr;
388
389 // Check that the raw data size matches what is expected for the scalar size.
390 // TODO: in theory, we could repack the data here to keep constructing from
391 // raw data.
392 // TODO: we may also need to consider endianness when cross-compiling to an
393 // architecture where it is different.
394 int64_t elementByteSize = denseElementsAttr.getRawData().size() /
395 denseElementsAttr.getNumElements();
396 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
397 return nullptr;
398
399 // Compute the shape of all dimensions but the innermost. Note that the
400 // innermost dimension may be that of the vector element type.
401 bool hasVectorElementType = isa<VectorType>(Val: type.getElementType());
402 int64_t numAggregates =
403 denseElementsAttr.getNumElements() /
404 (hasVectorElementType ? 1
405 : denseElementsAttr.getType().getShape().back());
406 ArrayRef<int64_t> outerShape = type.getShape();
407 if (!hasVectorElementType)
408 outerShape = outerShape.drop_back();
409
410 // Handle the case of vector splat, LLVM has special support for it.
411 if (denseElementsAttr.isSplat() &&
412 (isa<VectorType>(Val: type) || hasVectorElementType)) {
413 llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
414 llvmType: innermostLLVMType, attr: denseElementsAttr.getSplatValue<Attribute>(), loc,
415 moduleTranslation);
416 llvm::Constant *splatVector =
417 llvm::ConstantDataVector::getSplat(NumElts: 0, Elt: splatValue);
418 SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
419 ArrayRef<llvm::Constant *> constantsRef = constants;
420 return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc);
421 }
422 if (denseElementsAttr.isSplat())
423 return nullptr;
424
425 // In case of non-splat, create a constructor for the innermost constant from
426 // a piece of raw data.
427 std::function<llvm::Constant *(StringRef)> buildCstData;
428 if (isa<TensorType>(Val: type)) {
429 auto vectorElementType = dyn_cast<VectorType>(Val: type.getElementType());
430 if (vectorElementType && vectorElementType.getRank() == 1) {
431 buildCstData = [&](StringRef data) {
432 return llvm::ConstantDataVector::getRaw(
433 Data: data, NumElements: vectorElementType.getShape().back(), ElementTy: innermostLLVMType);
434 };
435 } else if (!vectorElementType) {
436 buildCstData = [&](StringRef data) {
437 return llvm::ConstantDataArray::getRaw(Data: data, NumElements: type.getShape().back(),
438 ElementTy: innermostLLVMType);
439 };
440 }
441 } else if (isa<VectorType>(Val: type)) {
442 buildCstData = [&](StringRef data) {
443 return llvm::ConstantDataVector::getRaw(Data: data, NumElements: type.getShape().back(),
444 ElementTy: innermostLLVMType);
445 };
446 }
447 if (!buildCstData)
448 return nullptr;
449
450 // Create innermost constants and defer to the default constant creation
451 // mechanism for other dimensions.
452 SmallVector<llvm::Constant *> constants;
453 int64_t aggregateSize = denseElementsAttr.getType().getShape().back() *
454 (innermostLLVMType->getScalarSizeInBits() / 8);
455 constants.reserve(N: numAggregates);
456 for (unsigned i = 0; i < numAggregates; ++i) {
457 StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
458 aggregateSize);
459 constants.push_back(Elt: buildCstData(data));
460 }
461
462 ArrayRef<llvm::Constant *> constantsRef = constants;
463 return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc);
464}
465
466/// Convert a dense resource elements attribute to an LLVM IR constant using its
467/// raw data storage if possible. This supports elements attributes of tensor or
468/// vector type and avoids constructing separate objects for individual values
469/// of the innermost dimension. Constants for other dimensions are still
470/// constructed recursively. Returns nullptr on failure and emits errors at
471/// `loc`.
472static llvm::Constant *convertDenseResourceElementsAttr(
473 Location loc, DenseResourceElementsAttr denseResourceAttr,
474 llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
475 assert(denseResourceAttr && "expected non-null attribute");
476
477 llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType);
478 if (!llvm::ConstantDataSequential::isElementTypeCompatible(
479 Ty: innermostLLVMType)) {
480 emitError(loc, message: "no known conversion for innermost element type");
481 return nullptr;
482 }
483
484 ShapedType type = denseResourceAttr.getType();
485 assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
486
487 AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
488 if (!blob) {
489 emitError(loc, message: "resource does not exist");
490 return nullptr;
491 }
492
493 ArrayRef<char> rawData = blob->getData();
494
495 // Check that the raw data size matches what is expected for the scalar size.
496 // TODO: in theory, we could repack the data here to keep constructing from
497 // raw data.
498 // TODO: we may also need to consider endianness when cross-compiling to an
499 // architecture where it is different.
500 int64_t numElements = denseResourceAttr.getType().getNumElements();
501 int64_t elementByteSize = rawData.size() / numElements;
502 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
503 emitError(loc, message: "raw data size does not match element type size");
504 return nullptr;
505 }
506
507 // Compute the shape of all dimensions but the innermost. Note that the
508 // innermost dimension may be that of the vector element type.
509 bool hasVectorElementType = isa<VectorType>(Val: type.getElementType());
510 int64_t numAggregates =
511 numElements / (hasVectorElementType
512 ? 1
513 : denseResourceAttr.getType().getShape().back());
514 ArrayRef<int64_t> outerShape = type.getShape();
515 if (!hasVectorElementType)
516 outerShape = outerShape.drop_back();
517
518 // Create a constructor for the innermost constant from a piece of raw data.
519 std::function<llvm::Constant *(StringRef)> buildCstData;
520 if (isa<TensorType>(Val: type)) {
521 auto vectorElementType = dyn_cast<VectorType>(Val: type.getElementType());
522 if (vectorElementType && vectorElementType.getRank() == 1) {
523 buildCstData = [&](StringRef data) {
524 return llvm::ConstantDataVector::getRaw(
525 Data: data, NumElements: vectorElementType.getShape().back(), ElementTy: innermostLLVMType);
526 };
527 } else if (!vectorElementType) {
528 buildCstData = [&](StringRef data) {
529 return llvm::ConstantDataArray::getRaw(Data: data, NumElements: type.getShape().back(),
530 ElementTy: innermostLLVMType);
531 };
532 }
533 } else if (isa<VectorType>(Val: type)) {
534 buildCstData = [&](StringRef data) {
535 return llvm::ConstantDataVector::getRaw(Data: data, NumElements: type.getShape().back(),
536 ElementTy: innermostLLVMType);
537 };
538 }
539 if (!buildCstData) {
540 emitError(loc, message: "unsupported dense_resource type");
541 return nullptr;
542 }
543
544 // Create innermost constants and defer to the default constant creation
545 // mechanism for other dimensions.
546 SmallVector<llvm::Constant *> constants;
547 int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
548 (innermostLLVMType->getScalarSizeInBits() / 8);
549 constants.reserve(N: numAggregates);
550 for (unsigned i = 0; i < numAggregates; ++i) {
551 StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
552 constants.push_back(Elt: buildCstData(data));
553 }
554
555 ArrayRef<llvm::Constant *> constantsRef = constants;
556 return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc);
557}
558
559/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
560/// This currently supports integer, floating point, splat and dense element
561/// attributes and combinations thereof. Also, an array attribute with two
562/// elements is supported to represent a complex constant. In case of error,
563/// report it to `loc` and return nullptr.
564llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
565 llvm::Type *llvmType, Attribute attr, Location loc,
566 const ModuleTranslation &moduleTranslation) {
567 if (!attr || isa<UndefAttr>(Val: attr))
568 return llvm::UndefValue::get(T: llvmType);
569 if (isa<ZeroAttr>(Val: attr))
570 return llvm::Constant::getNullValue(Ty: llvmType);
571 if (auto *structType = dyn_cast<::llvm::StructType>(Val: llvmType)) {
572 auto arrayAttr = dyn_cast<ArrayAttr>(Val&: attr);
573 if (!arrayAttr) {
574 emitError(loc, message: "expected an array attribute for a struct constant");
575 return nullptr;
576 }
577 SmallVector<llvm::Constant *> structElements;
578 structElements.reserve(N: structType->getNumElements());
579 for (auto [elemType, elemAttr] :
580 zip_equal(t: structType->elements(), u&: arrayAttr)) {
581 llvm::Constant *element =
582 getLLVMConstant(llvmType: elemType, attr: elemAttr, loc, moduleTranslation);
583 if (!element)
584 return nullptr;
585 structElements.push_back(Elt: element);
586 }
587 return llvm::ConstantStruct::get(T: structType, V: structElements);
588 }
589 // For integer types, we allow a mismatch in sizes as the index type in
590 // MLIR might have a different size than the index type in the LLVM module.
591 if (auto intAttr = dyn_cast<IntegerAttr>(Val&: attr))
592 return llvm::ConstantInt::get(
593 Ty: llvmType,
594 V: intAttr.getValue().sextOrTrunc(width: llvmType->getIntegerBitWidth()));
595 if (auto floatAttr = dyn_cast<FloatAttr>(Val&: attr)) {
596 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
597 // Special case for 8-bit floats, which are represented by integers due to
598 // the lack of native fp8 types in LLVM at the moment. Additionally, handle
599 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
600 // to i16.
601 unsigned floatWidth = APFloat::getSizeInBits(Sem: sem);
602 if (llvmType->isIntegerTy(Bitwidth: floatWidth))
603 return llvm::ConstantInt::get(Ty: llvmType,
604 V: floatAttr.getValue().bitcastToAPInt());
605 if (llvmType !=
606 llvm::Type::getFloatingPointTy(C&: llvmType->getContext(),
607 S: floatAttr.getValue().getSemantics())) {
608 emitError(loc, message: "FloatAttr does not match expected type of the constant");
609 return nullptr;
610 }
611 return llvm::ConstantFP::get(Ty: llvmType, V: floatAttr.getValue());
612 }
613 if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(Val&: attr))
614 return llvm::ConstantExpr::getBitCast(
615 C: moduleTranslation.lookupFunction(name: funcAttr.getValue()), Ty: llvmType);
616 if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr)) {
617 llvm::Type *elementType;
618 uint64_t numElements;
619 bool isScalable = false;
620 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) {
621 elementType = arrayTy->getElementType();
622 numElements = arrayTy->getNumElements();
623 } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(Val: llvmType)) {
624 elementType = fVectorTy->getElementType();
625 numElements = fVectorTy->getNumElements();
626 } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(Val: llvmType)) {
627 elementType = sVectorTy->getElementType();
628 numElements = sVectorTy->getMinNumElements();
629 isScalable = true;
630 } else {
631 llvm_unreachable("unrecognized constant vector type");
632 }
633 // Splat value is a scalar. Extract it only if the element type is not
634 // another sequence type. The recursion terminates because each step removes
635 // one outer sequential type.
636 bool elementTypeSequential =
637 isa<llvm::ArrayType, llvm::VectorType>(Val: elementType);
638 llvm::Constant *child = getLLVMConstant(
639 llvmType: elementType,
640 attr: elementTypeSequential ? splatAttr
641 : splatAttr.getSplatValue<Attribute>(),
642 loc, moduleTranslation);
643 if (!child)
644 return nullptr;
645 if (llvmType->isVectorTy())
646 return llvm::ConstantVector::getSplat(
647 EC: llvm::ElementCount::get(MinVal: numElements, /*Scalable=*/isScalable), Elt: child);
648 if (llvmType->isArrayTy()) {
649 auto *arrayType = llvm::ArrayType::get(ElementType: elementType, NumElements: numElements);
650 if (child->isZeroValue()) {
651 return llvm::ConstantAggregateZero::get(Ty: arrayType);
652 } else {
653 if (llvm::ConstantDataSequential::isElementTypeCompatible(
654 Ty: elementType)) {
655 // TODO: Handle all compatible types. This code only handles integer.
656 if (isa<llvm::IntegerType>(Val: elementType)) {
657 if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(Val: child)) {
658 if (ci->getBitWidth() == 8) {
659 SmallVector<int8_t> constants(numElements, ci->getZExtValue());
660 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
661 Elts&: constants);
662 }
663 if (ci->getBitWidth() == 16) {
664 SmallVector<int16_t> constants(numElements, ci->getZExtValue());
665 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
666 Elts&: constants);
667 }
668 if (ci->getBitWidth() == 32) {
669 SmallVector<int32_t> constants(numElements, ci->getZExtValue());
670 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
671 Elts&: constants);
672 }
673 if (ci->getBitWidth() == 64) {
674 SmallVector<int64_t> constants(numElements, ci->getZExtValue());
675 return llvm::ConstantDataArray::get(Context&: elementType->getContext(),
676 Elts&: constants);
677 }
678 }
679 }
680 }
681 // std::vector is used here to accomodate large number of elements that
682 // exceed SmallVector capacity.
683 std::vector<llvm::Constant *> constants(numElements, child);
684 return llvm::ConstantArray::get(T: arrayType, V: constants);
685 }
686 }
687 }
688
689 // Try using raw elements data if possible.
690 if (llvm::Constant *result =
691 convertDenseElementsAttr(loc, denseElementsAttr: dyn_cast<DenseElementsAttr>(Val&: attr),
692 llvmType, moduleTranslation)) {
693 return result;
694 }
695
696 if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(Val&: attr)) {
697 return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
698 moduleTranslation);
699 }
700
701 // Fall back to element-by-element construction otherwise.
702 if (auto elementsAttr = dyn_cast<ElementsAttr>(Val&: attr)) {
703 assert(elementsAttr.getShapedType().hasStaticShape());
704 assert(!elementsAttr.getShapedType().getShape().empty() &&
705 "unexpected empty elements attribute shape");
706
707 SmallVector<llvm::Constant *, 8> constants;
708 constants.reserve(N: elementsAttr.getNumElements());
709 llvm::Type *innermostType = getInnermostElementType(type: llvmType);
710 for (auto n : elementsAttr.getValues<Attribute>()) {
711 constants.push_back(
712 Elt: getLLVMConstant(llvmType: innermostType, attr: n, loc, moduleTranslation));
713 if (!constants.back())
714 return nullptr;
715 }
716 ArrayRef<llvm::Constant *> constantsRef = constants;
717 llvm::Constant *result = buildSequentialConstant(
718 constants&: constantsRef, shape: elementsAttr.getShapedType().getShape(), type: llvmType, loc);
719 assert(constantsRef.empty() && "did not consume all elemental constants");
720 return result;
721 }
722
723 if (auto stringAttr = dyn_cast<StringAttr>(Val&: attr)) {
724 return llvm::ConstantDataArray::get(Context&: moduleTranslation.getLLVMContext(),
725 Elts: ArrayRef<char>{stringAttr.getValue()});
726 }
727
728 // Handle arrays of structs that cannot be represented as DenseElementsAttr
729 // in MLIR.
730 if (auto arrayAttr = dyn_cast<ArrayAttr>(Val&: attr)) {
731 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) {
732 llvm::Type *elementType = arrayTy->getElementType();
733 Attribute previousElementAttr;
734 llvm::Constant *elementCst = nullptr;
735 SmallVector<llvm::Constant *> constants;
736 constants.reserve(N: arrayTy->getNumElements());
737 for (Attribute elementAttr : arrayAttr) {
738 // Arrays with a single value or with repeating values are quite common.
739 // Short-circuit the translation when the element value is the same as
740 // the previous one.
741 if (!previousElementAttr || previousElementAttr != elementAttr) {
742 previousElementAttr = elementAttr;
743 elementCst =
744 getLLVMConstant(llvmType: elementType, attr: elementAttr, loc, moduleTranslation);
745 if (!elementCst)
746 return nullptr;
747 }
748 constants.push_back(Elt: elementCst);
749 }
750 return llvm::ConstantArray::get(T: arrayTy, V: constants);
751 }
752 }
753
754 emitError(loc, message: "unsupported constant value");
755 return nullptr;
756}
757
758ModuleTranslation::ModuleTranslation(Operation *module,
759 std::unique_ptr<llvm::Module> llvmModule)
760 : mlirModule(module), llvmModule(std::move(llvmModule)),
761 debugTranslation(
762 std::make_unique<DebugTranslation>(args&: module, args&: *this->llvmModule)),
763 loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
764 args&: *this, args&: *this->llvmModule)),
765 typeTranslator(this->llvmModule->getContext()),
766 iface(module->getContext()) {
767 assert(satisfiesLLVMModule(mlirModule) &&
768 "mlirModule should honor LLVM's module semantics.");
769}
770
771ModuleTranslation::~ModuleTranslation() {
772 if (ompBuilder && !ompBuilder->isFinalized())
773 ompBuilder->finalize();
774}
775
776void ModuleTranslation::forgetMapping(Region &region) {
777 SmallVector<Region *> toProcess;
778 toProcess.push_back(Elt: &region);
779 while (!toProcess.empty()) {
780 Region *current = toProcess.pop_back_val();
781 for (Block &block : *current) {
782 blockMapping.erase(Val: &block);
783 for (Value arg : block.getArguments())
784 valueMapping.erase(Val: arg);
785 for (Operation &op : block) {
786 for (Value value : op.getResults())
787 valueMapping.erase(Val: value);
788 if (op.hasSuccessors())
789 branchMapping.erase(Val: &op);
790 if (isa<LLVM::GlobalOp>(Val: op))
791 globalsMapping.erase(Val: &op);
792 if (isa<LLVM::AliasOp>(Val: op))
793 aliasesMapping.erase(Val: &op);
794 if (isa<LLVM::CallOp>(Val: op))
795 callMapping.erase(Val: &op);
796 llvm::append_range(
797 C&: toProcess,
798 R: llvm::map_range(C: op.getRegions(), F: [](Region &r) { return &r; }));
799 }
800 }
801 }
802}
803
804/// Get the SSA value passed to the current block from the terminator operation
805/// of its predecessor.
806static Value getPHISourceValue(Block *current, Block *pred,
807 unsigned numArguments, unsigned index) {
808 Operation &terminator = *pred->getTerminator();
809 if (isa<LLVM::BrOp>(Val: terminator))
810 return terminator.getOperand(idx: index);
811
812#ifndef NDEBUG
813 llvm::SmallPtrSet<Block *, 4> seenSuccessors;
814 for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
815 Block *successor = terminator.getSuccessor(i);
816 auto branch = cast<BranchOpInterface>(terminator);
817 SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
818 assert(
819 (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
820 "successors with arguments in LLVM branches must be different blocks");
821 seenSuccessors.insert(successor);
822 }
823#endif
824
825 // For instructions that branch based on a condition value, we need to take
826 // the operands for the branch that was taken.
827 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(Val&: terminator)) {
828 // For conditional branches, we take the operands from either the "true" or
829 // the "false" branch.
830 return condBranchOp.getSuccessor(i: 0) == current
831 ? condBranchOp.getTrueDestOperands()[index]
832 : condBranchOp.getFalseDestOperands()[index];
833 }
834
835 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(Val&: terminator)) {
836 // For switches, we take the operands from either the default case, or from
837 // the case branch that was taken.
838 if (switchOp.getDefaultDestination() == current)
839 return switchOp.getDefaultOperands()[index];
840 for (const auto &i : llvm::enumerate(First: switchOp.getCaseDestinations()))
841 if (i.value() == current)
842 return switchOp.getCaseOperands(index: i.index())[index];
843 }
844
845 if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(Val&: terminator)) {
846 // For indirect branches we take operands for each successor.
847 for (const auto &i : llvm::enumerate(First: indBrOp->getSuccessors())) {
848 if (indBrOp->getSuccessor(index: i.index()) == current)
849 return indBrOp.getSuccessorOperands(index: i.index())[index];
850 }
851 }
852
853 if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(Val&: terminator)) {
854 return invokeOp.getNormalDest() == current
855 ? invokeOp.getNormalDestOperands()[index]
856 : invokeOp.getUnwindDestOperands()[index];
857 }
858
859 llvm_unreachable(
860 "only branch, switch or invoke operations can be terminators "
861 "of a block that has successors");
862}
863
864/// Connect the PHI nodes to the results of preceding blocks.
865void mlir::LLVM::detail::connectPHINodes(Region &region,
866 const ModuleTranslation &state) {
867 // Skip the first block, it cannot be branched to and its arguments correspond
868 // to the arguments of the LLVM function.
869 for (Block &bb : llvm::drop_begin(RangeOrContainer&: region)) {
870 llvm::BasicBlock *llvmBB = state.lookupBlock(block: &bb);
871 auto phis = llvmBB->phis();
872 auto numArguments = bb.getNumArguments();
873 assert(numArguments == std::distance(phis.begin(), phis.end()));
874 for (auto [index, phiNode] : llvm::enumerate(First&: phis)) {
875 for (auto *pred : bb.getPredecessors()) {
876 // Find the LLVM IR block that contains the converted terminator
877 // instruction and use it in the PHI node. Note that this block is not
878 // necessarily the same as state.lookupBlock(pred), some operations
879 // (in particular, OpenMP operations using OpenMPIRBuilder) may have
880 // split the blocks.
881 llvm::Instruction *terminator =
882 state.lookupBranch(op: pred->getTerminator());
883 assert(terminator && "missing the mapping for a terminator");
884 phiNode.addIncoming(V: state.lookupValue(value: getPHISourceValue(
885 current: &bb, pred, numArguments, index)),
886 BB: terminator->getParent());
887 }
888 }
889 }
890}
891
892llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
893 llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
894 ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
895 llvm::Module *module = builder.GetInsertBlock()->getModule();
896 llvm::Function *fn =
897 llvm::Intrinsic::getOrInsertDeclaration(M: module, id: intrinsic, Tys: tys);
898 return builder.CreateCall(Callee: fn, Args: args);
899}
900
901llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
902 llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
903 Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
904 ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
905 ArrayRef<unsigned> immArgPositions,
906 ArrayRef<StringLiteral> immArgAttrNames) {
907 assert(immArgPositions.size() == immArgAttrNames.size() &&
908 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
909 "length");
910
911 SmallVector<llvm::OperandBundleDef> opBundles;
912 size_t numOpBundleOperands = 0;
913 auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
914 Val: intrOp->getAttr(name: LLVMDialect::getOpBundleSizesAttrName()));
915 auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
916 Val: intrOp->getAttr(name: LLVMDialect::getOpBundleTagsAttrName()));
917
918 if (opBundleSizesAttr && opBundleTagsAttr) {
919 ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
920 assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
921 "operand bundles and tags do not match");
922
923 numOpBundleOperands =
924 std::accumulate(first: opBundleSizes.begin(), last: opBundleSizes.end(), init: size_t(0));
925 assert(numOpBundleOperands <= intrOp->getNumOperands() &&
926 "operand bundle operands is more than the number of operands");
927
928 ValueRange operands = intrOp->getOperands().take_back(n: numOpBundleOperands);
929 size_t nextOperandIdx = 0;
930 opBundles.reserve(N: opBundleSizesAttr.size());
931
932 for (auto [opBundleTagAttr, bundleSize] :
933 llvm::zip(t&: opBundleTagsAttr, u&: opBundleSizes)) {
934 auto bundleTag = cast<StringAttr>(Val: opBundleTagAttr).str();
935 auto bundleOperands = moduleTranslation.lookupValues(
936 values: operands.slice(n: nextOperandIdx, m: bundleSize));
937 opBundles.emplace_back(Args: std::move(bundleTag), Args: std::move(bundleOperands));
938 nextOperandIdx += bundleSize;
939 }
940 }
941
942 // Map operands and attributes to LLVM values.
943 auto opOperands = intrOp->getOperands().drop_back(n: numOpBundleOperands);
944 auto operands = moduleTranslation.lookupValues(values: opOperands);
945 SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
946 for (auto [immArgPos, immArgName] :
947 llvm::zip(t&: immArgPositions, u&: immArgAttrNames)) {
948 auto attr = llvm::cast<TypedAttr>(Val: intrOp->getAttr(name: immArgName));
949 assert(attr.getType().isIntOrFloat() && "expected int or float immarg");
950 auto *type = moduleTranslation.convertType(type: attr.getType());
951 args[immArgPos] = LLVM::detail::getLLVMConstant(
952 llvmType: type, attr, loc: intrOp->getLoc(), moduleTranslation);
953 }
954 unsigned opArg = 0;
955 for (auto &arg : args) {
956 if (!arg)
957 arg = operands[opArg++];
958 }
959
960 // Resolve overloaded intrinsic declaration.
961 SmallVector<llvm::Type *> overloadedTypes;
962 for (unsigned overloadedResultIdx : overloadedResults) {
963 if (numResults > 1) {
964 // More than one result is mapped to an LLVM struct.
965 overloadedTypes.push_back(Elt: moduleTranslation.convertType(
966 type: llvm::cast<LLVM::LLVMStructType>(Val: intrOp->getResult(idx: 0).getType())
967 .getBody()[overloadedResultIdx]));
968 } else {
969 overloadedTypes.push_back(
970 Elt: moduleTranslation.convertType(type: intrOp->getResult(idx: 0).getType()));
971 }
972 }
973 for (unsigned overloadedOperandIdx : overloadedOperands)
974 overloadedTypes.push_back(Elt: args[overloadedOperandIdx]->getType());
975 llvm::Module *module = builder.GetInsertBlock()->getModule();
976 llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
977 M: module, id: intrinsic, Tys: overloadedTypes);
978
979 return builder.CreateCall(Callee: llvmIntr, Args: args, OpBundles: opBundles);
980}
981
982/// Given a single MLIR operation, create the corresponding LLVM IR operation
983/// using the `builder`.
984LogicalResult ModuleTranslation::convertOperation(Operation &op,
985 llvm::IRBuilderBase &builder,
986 bool recordInsertions) {
987 const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(obj: &op);
988 if (!opIface)
989 return op.emitError(message: "cannot be converted to LLVM IR: missing "
990 "`LLVMTranslationDialectInterface` registration for "
991 "dialect for op: ")
992 << op.getName();
993
994 InstructionCapturingInserter::CollectionScope scope(builder,
995 recordInsertions);
996 if (failed(Result: opIface->convertOperation(op: &op, builder, moduleTranslation&: *this)))
997 return op.emitError(message: "LLVM Translation failed for operation: ")
998 << op.getName();
999
1000 return convertDialectAttributes(op: &op, instructions: scope.getCapturedInstructions());
1001}
1002
1003/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
1004/// to define values corresponding to the MLIR block arguments. These nodes
1005/// are not connected to the source basic blocks, which may not exist yet. Uses
1006/// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
1007/// been created for `bb` and included in the block mapping. Inserts new
1008/// instructions at the end of the block and leaves `builder` in a state
1009/// suitable for further insertion into the end of the block.
1010LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
1011 bool ignoreArguments,
1012 llvm::IRBuilderBase &builder,
1013 bool recordInsertions) {
1014 builder.SetInsertPoint(lookupBlock(block: &bb));
1015 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
1016
1017 // Before traversing operations, make block arguments available through
1018 // value remapping and PHI nodes, but do not add incoming edges for the PHI
1019 // nodes just yet: those values may be defined by this or following blocks.
1020 // This step is omitted if "ignoreArguments" is set. The arguments of the
1021 // first block have been already made available through the remapping of
1022 // LLVM function arguments.
1023 if (!ignoreArguments) {
1024 auto predecessors = bb.getPredecessors();
1025 unsigned numPredecessors =
1026 std::distance(first: predecessors.begin(), last: predecessors.end());
1027 for (auto arg : bb.getArguments()) {
1028 auto wrappedType = arg.getType();
1029 if (!isCompatibleType(type: wrappedType))
1030 return emitError(loc: bb.front().getLoc(),
1031 message: "block argument does not have an LLVM type");
1032 builder.SetCurrentDebugLocation(
1033 debugTranslation->translateLoc(loc: arg.getLoc(), scope: subprogram));
1034 llvm::Type *type = convertType(type: wrappedType);
1035 llvm::PHINode *phi = builder.CreatePHI(Ty: type, NumReservedValues: numPredecessors);
1036 mapValue(mlir: arg, llvm: phi);
1037 }
1038 }
1039
1040 // Traverse operations.
1041 for (auto &op : bb) {
1042 // Set the current debug location within the builder.
1043 builder.SetCurrentDebugLocation(
1044 debugTranslation->translateLoc(loc: op.getLoc(), scope: subprogram));
1045
1046 if (failed(Result: convertOperation(op, builder, recordInsertions)))
1047 return failure();
1048
1049 // Set the branch weight metadata on the translated instruction.
1050 if (auto iface = dyn_cast<WeightedBranchOpInterface>(Val&: op))
1051 setBranchWeightsMetadata(iface);
1052 }
1053
1054 return success();
1055}
1056
1057/// A helper method to get the single Block in an operation honoring LLVM's
1058/// module requirements.
1059static Block &getModuleBody(Operation *module) {
1060 return module->getRegion(index: 0).front();
1061}
1062
1063/// A helper method to decide if a constant must not be set as a global variable
1064/// initializer. For an external linkage variable, the variable with an
1065/// initializer is considered externally visible and defined in this module, the
1066/// variable without an initializer is externally available and is defined
1067/// elsewhere.
1068static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
1069 llvm::Constant *cst) {
1070 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
1071 linkage == llvm::GlobalVariable::ExternalWeakLinkage;
1072}
1073
1074/// Sets the runtime preemption specifier of `gv` to dso_local if
1075/// `dsoLocalRequested` is true, otherwise it is left unchanged.
1076static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
1077 llvm::GlobalValue *gv) {
1078 if (dsoLocalRequested)
1079 gv->setDSOLocal(true);
1080}
1081
1082LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
1083 // Mapping from compile unit to its respective set of global variables.
1084 DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars;
1085
1086 // First, create all global variables and global aliases in LLVM IR. A global
1087 // or alias body may refer to another global/alias or itself, so all the
1088 // mapping needs to happen prior to body conversion.
1089
1090 // Create all llvm::GlobalVariable
1091 for (auto op : getModuleBody(module: mlirModule).getOps<LLVM::GlobalOp>()) {
1092 llvm::Type *type = convertType(type: op.getType());
1093 llvm::Constant *cst = nullptr;
1094 if (op.getValueOrNull()) {
1095 // String attributes are treated separately because they cannot appear as
1096 // in-function constants and are thus not supported by getLLVMConstant.
1097 if (auto strAttr = dyn_cast_or_null<StringAttr>(Val: op.getValueOrNull())) {
1098 cst = llvm::ConstantDataArray::getString(
1099 Context&: llvmModule->getContext(), Initializer: strAttr.getValue(), /*AddNull=*/false);
1100 type = cst->getType();
1101 } else if (!(cst = getLLVMConstant(llvmType: type, attr: op.getValueOrNull(), loc: op.getLoc(),
1102 moduleTranslation: *this))) {
1103 return failure();
1104 }
1105 }
1106
1107 auto linkage = convertLinkageToLLVM(value: op.getLinkage());
1108
1109 // LLVM IR requires constant with linkage other than external or weak
1110 // external to have initializers. If MLIR does not provide an initializer,
1111 // default to undef.
1112 bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
1113 if (!dropInitializer && !cst)
1114 cst = llvm::UndefValue::get(T: type);
1115 else if (dropInitializer && cst)
1116 cst = nullptr;
1117
1118 auto *var = new llvm::GlobalVariable(
1119 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
1120 /*InsertBefore=*/nullptr,
1121 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
1122 : llvm::GlobalValue::NotThreadLocal,
1123 op.getAddrSpace(), op.getExternallyInitialized());
1124
1125 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
1126 auto selectorOp = cast<ComdatSelectorOp>(
1127 Val: SymbolTable::lookupNearestSymbolFrom(from: op, symbol: *comdat));
1128 var->setComdat(comdatMapping.lookup(Val: selectorOp));
1129 }
1130
1131 if (op.getUnnamedAddr().has_value())
1132 var->setUnnamedAddr(convertUnnamedAddrToLLVM(value: *op.getUnnamedAddr()));
1133
1134 if (op.getSection().has_value())
1135 var->setSection(*op.getSection());
1136
1137 addRuntimePreemptionSpecifier(dsoLocalRequested: op.getDsoLocal(), gv: var);
1138
1139 std::optional<uint64_t> alignment = op.getAlignment();
1140 if (alignment.has_value())
1141 var->setAlignment(llvm::MaybeAlign(alignment.value()));
1142
1143 var->setVisibility(convertVisibilityToLLVM(value: op.getVisibility_()));
1144
1145 globalsMapping.try_emplace(Key: op, Args&: var);
1146
1147 // Add debug information if present.
1148 if (op.getDbgExprs()) {
1149 for (auto exprAttr :
1150 op.getDbgExprs()->getAsRange<DIGlobalVariableExpressionAttr>()) {
1151 llvm::DIGlobalVariableExpression *diGlobalExpr =
1152 debugTranslation->translateGlobalVariableExpression(attr: exprAttr);
1153 llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable();
1154 var->addDebugInfo(GV: diGlobalExpr);
1155
1156 // There is no `globals` field in DICompileUnitAttr which can be
1157 // directly assigned to DICompileUnit. We have to build the list by
1158 // looking at the dbgExpr of all the GlobalOps. The scope of the
1159 // variable is used to get the DICompileUnit in which to add it. But
1160 // there are cases where the scope of a global does not directly point
1161 // to the DICompileUnit and we have to do a bit more work to get to
1162 // it. Some of those cases are:
1163 //
1164 // 1. For the languages that support modules, the scope hierarchy can
1165 // be variable -> DIModule -> DICompileUnit
1166 //
1167 // 2. For the Fortran common block variable, the scope hierarchy can
1168 // be variable -> DICommonBlock -> DISubprogram -> DICompileUnit
1169 //
1170 // 3. For entities like static local variables in C or variable with
1171 // SAVE attribute in Fortran, the scope hierarchy can be
1172 // variable -> DISubprogram -> DICompileUnit
1173 llvm::DIScope *scope = diGlobalVar->getScope();
1174 if (auto *mod = dyn_cast_if_present<llvm::DIModule>(Val: scope))
1175 scope = mod->getScope();
1176 else if (auto *cb = dyn_cast_if_present<llvm::DICommonBlock>(Val: scope)) {
1177 if (auto *sp =
1178 dyn_cast_if_present<llvm::DISubprogram>(Val: cb->getScope()))
1179 scope = sp->getUnit();
1180 } else if (auto *sp = dyn_cast_if_present<llvm::DISubprogram>(Val: scope))
1181 scope = sp->getUnit();
1182
1183 // Get the compile unit (scope) of the the global variable.
1184 if (llvm::DICompileUnit *compileUnit =
1185 dyn_cast_if_present<llvm::DICompileUnit>(Val: scope)) {
1186 // Update the compile unit with this incoming global variable
1187 // expression during the finalizing step later.
1188 allGVars[compileUnit].push_back(Elt: diGlobalExpr);
1189 }
1190 }
1191 }
1192 }
1193
1194 // Create all llvm::GlobalAlias
1195 for (auto op : getModuleBody(module: mlirModule).getOps<LLVM::AliasOp>()) {
1196 llvm::Type *type = convertType(type: op.getType());
1197 llvm::Constant *cst = nullptr;
1198 llvm::GlobalValue::LinkageTypes linkage =
1199 convertLinkageToLLVM(value: op.getLinkage());
1200 llvm::Module &llvmMod = *llvmModule;
1201
1202 // Note address space and aliasee info isn't set just yet.
1203 llvm::GlobalAlias *var = llvm::GlobalAlias::create(
1204 Ty: type, AddressSpace: op.getAddrSpace(), Linkage: linkage, Name: op.getSymName(), /*placeholder*/ Aliasee: cst,
1205 Parent: &llvmMod);
1206
1207 var->setThreadLocalMode(op.getThreadLocal_()
1208 ? llvm::GlobalAlias::GeneralDynamicTLSModel
1209 : llvm::GlobalAlias::NotThreadLocal);
1210
1211 // Note there is no need to setup the comdat because GlobalAlias calls into
1212 // the aliasee comdat information automatically.
1213
1214 if (op.getUnnamedAddr().has_value())
1215 var->setUnnamedAddr(convertUnnamedAddrToLLVM(value: *op.getUnnamedAddr()));
1216
1217 var->setVisibility(convertVisibilityToLLVM(value: op.getVisibility_()));
1218
1219 aliasesMapping.try_emplace(Key: op, Args&: var);
1220 }
1221
1222 // Convert global variable bodies.
1223 for (auto op : getModuleBody(module: mlirModule).getOps<LLVM::GlobalOp>()) {
1224 if (Block *initializer = op.getInitializerBlock()) {
1225 llvm::IRBuilder<llvm::TargetFolder> builder(
1226 llvmModule->getContext(),
1227 llvm::TargetFolder(llvmModule->getDataLayout()));
1228
1229 [[maybe_unused]] int numConstantsHit = 0;
1230 [[maybe_unused]] int numConstantsErased = 0;
1231 DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
1232
1233 for (auto &op : initializer->without_terminator()) {
1234 if (failed(Result: convertOperation(op, builder)))
1235 return emitError(loc: op.getLoc(), message: "fail to convert global initializer");
1236 auto *cst = dyn_cast<llvm::Constant>(Val: lookupValue(value: op.getResult(idx: 0)));
1237 if (!cst)
1238 return emitError(loc: op.getLoc(), message: "unemittable constant value");
1239
1240 // When emitting an LLVM constant, a new constant is created and the old
1241 // constant may become dangling and take space. We should remove the
1242 // dangling constants to avoid memory explosion especially for constant
1243 // arrays whose number of elements is large.
1244 // Because multiple operations may refer to the same constant, we need
1245 // to count the number of uses of each constant array and remove it only
1246 // when the count becomes zero.
1247 if (auto *agg = dyn_cast<llvm::ConstantAggregate>(Val: cst)) {
1248 numConstantsHit++;
1249 Value result = op.getResult(idx: 0);
1250 int numUsers = std::distance(first: result.use_begin(), last: result.use_end());
1251 auto [iterator, inserted] =
1252 constantAggregateUseMap.try_emplace(Key: agg, Args&: numUsers);
1253 if (!inserted) {
1254 // Key already exists, update the value
1255 iterator->second += numUsers;
1256 }
1257 }
1258 // Scan the operands of the operation to decrement the use count of
1259 // constants. Erase the constant if the use count becomes zero.
1260 for (Value v : op.getOperands()) {
1261 auto cst = dyn_cast<llvm::ConstantAggregate>(Val: lookupValue(value: v));
1262 if (!cst)
1263 continue;
1264 auto iter = constantAggregateUseMap.find(Val: cst);
1265 assert(iter != constantAggregateUseMap.end() && "constant not found");
1266 iter->second--;
1267 if (iter->second == 0) {
1268 // NOTE: cannot call removeDeadConstantUsers() here because it
1269 // may remove the constant which has uses not be converted yet.
1270 if (cst->user_empty()) {
1271 cst->destroyConstant();
1272 numConstantsErased++;
1273 }
1274 constantAggregateUseMap.erase(I: iter);
1275 }
1276 }
1277 }
1278
1279 ReturnOp ret = cast<ReturnOp>(Val: initializer->getTerminator());
1280 llvm::Constant *cst =
1281 cast<llvm::Constant>(Val: lookupValue(value: ret.getOperand(i: 0)));
1282 auto *global = cast<llvm::GlobalVariable>(Val: lookupGlobal(op));
1283 if (!shouldDropGlobalInitializer(linkage: global->getLinkage(), cst))
1284 global->setInitializer(cst);
1285
1286 // Try to remove the dangling constants again after all operations are
1287 // converted.
1288 for (auto it : constantAggregateUseMap) {
1289 auto cst = it.first;
1290 cst->removeDeadConstantUsers();
1291 if (cst->user_empty()) {
1292 cst->destroyConstant();
1293 numConstantsErased++;
1294 }
1295 }
1296
1297 LLVM_DEBUG(llvm::dbgs()
1298 << "Convert initializer for " << op.getName() << "\n";
1299 llvm::dbgs() << numConstantsHit << " new constants hit\n";
1300 llvm::dbgs()
1301 << numConstantsErased << " dangling constants erased\n";);
1302 }
1303 }
1304
1305 // Convert llvm.mlir.global_ctors and dtors.
1306 for (Operation &op : getModuleBody(module: mlirModule)) {
1307 auto ctorOp = dyn_cast<GlobalCtorsOp>(Val&: op);
1308 auto dtorOp = dyn_cast<GlobalDtorsOp>(Val&: op);
1309 if (!ctorOp && !dtorOp)
1310 continue;
1311
1312 // The empty / zero initialized version of llvm.global_(c|d)tors cannot be
1313 // handled by appendGlobalFn logic below, which just ignores empty (c|d)tor
1314 // lists. Make sure it gets emitted.
1315 if ((ctorOp && ctorOp.getCtors().empty()) ||
1316 (dtorOp && dtorOp.getDtors().empty())) {
1317 llvm::IRBuilder<llvm::TargetFolder> builder(
1318 llvmModule->getContext(),
1319 llvm::TargetFolder(llvmModule->getDataLayout()));
1320 llvm::Type *eltTy = llvm::StructType::get(
1321 elt1: builder.getInt32Ty(), elts: builder.getPtrTy(), elts: builder.getPtrTy());
1322 llvm::ArrayType *at = llvm::ArrayType::get(ElementType: eltTy, NumElements: 0);
1323 llvm::Constant *zeroInit = llvm::Constant::getNullValue(Ty: at);
1324 (void)new llvm::GlobalVariable(
1325 *llvmModule, zeroInit->getType(), false,
1326 llvm::GlobalValue::AppendingLinkage, zeroInit,
1327 ctorOp ? "llvm.global_ctors" : "llvm.global_dtors");
1328 } else {
1329 auto range = ctorOp
1330 ? llvm::zip(t: ctorOp.getCtors(), u: ctorOp.getPriorities())
1331 : llvm::zip(t: dtorOp.getDtors(), u: dtorOp.getPriorities());
1332 auto appendGlobalFn =
1333 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
1334 for (const auto &[sym, prio] : range) {
1335 llvm::Function *f =
1336 lookupFunction(name: cast<FlatSymbolRefAttr>(Val: sym).getValue());
1337 appendGlobalFn(*llvmModule, f, cast<IntegerAttr>(Val: prio).getInt(),
1338 /*Data=*/nullptr);
1339 }
1340 }
1341 }
1342
1343 for (auto op : getModuleBody(module: mlirModule).getOps<LLVM::GlobalOp>())
1344 if (failed(Result: convertDialectAttributes(op, instructions: {})))
1345 return failure();
1346
1347 // Finally, update the compile units their respective sets of global variables
1348 // created earlier.
1349 for (const auto &[compileUnit, globals] : allGVars) {
1350 compileUnit->replaceGlobalVariables(
1351 N: llvm::MDTuple::get(Context&: getLLVMContext(), MDs: globals));
1352 }
1353
1354 // Convert global alias bodies.
1355 for (auto op : getModuleBody(module: mlirModule).getOps<LLVM::AliasOp>()) {
1356 Block &initializer = op.getInitializerBlock();
1357 llvm::IRBuilder<llvm::TargetFolder> builder(
1358 llvmModule->getContext(),
1359 llvm::TargetFolder(llvmModule->getDataLayout()));
1360
1361 for (mlir::Operation &op : initializer.without_terminator()) {
1362 if (failed(Result: convertOperation(op, builder)))
1363 return emitError(loc: op.getLoc(), message: "fail to convert alias initializer");
1364 if (!isa<llvm::Constant>(Val: lookupValue(value: op.getResult(idx: 0))))
1365 return emitError(loc: op.getLoc(), message: "unemittable constant value");
1366 }
1367
1368 auto ret = cast<ReturnOp>(Val: initializer.getTerminator());
1369 auto *cst = cast<llvm::Constant>(Val: lookupValue(value: ret.getOperand(i: 0)));
1370 assert(aliasesMapping.count(op));
1371 auto *alias = cast<llvm::GlobalAlias>(Val: aliasesMapping[op]);
1372 alias->setAliasee(cst);
1373 }
1374
1375 for (auto op : getModuleBody(module: mlirModule).getOps<LLVM::AliasOp>())
1376 if (failed(Result: convertDialectAttributes(op, instructions: {})))
1377 return failure();
1378
1379 return success();
1380}
1381
1382/// Attempts to add an attribute identified by `key`, optionally with the given
1383/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
1384/// attribute has a kind known to LLVM IR, create the attribute of this kind,
1385/// otherwise keep it as a string attribute. Performs additional checks for
1386/// attributes known to have or not have a value in order to avoid assertions
1387/// inside LLVM upon construction.
1388static LogicalResult checkedAddLLVMFnAttribute(Location loc,
1389 llvm::Function *llvmFunc,
1390 StringRef key,
1391 StringRef value = StringRef()) {
1392 auto kind = llvm::Attribute::getAttrKindFromName(AttrName: key);
1393 if (kind == llvm::Attribute::None) {
1394 llvmFunc->addFnAttr(Kind: key, Val: value);
1395 return success();
1396 }
1397
1398 if (llvm::Attribute::isIntAttrKind(Kind: kind)) {
1399 if (value.empty())
1400 return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
1401
1402 int64_t result;
1403 if (!value.getAsInteger(/*Radix=*/0, Result&: result))
1404 llvmFunc->addFnAttr(
1405 Attr: llvm::Attribute::get(Context&: llvmFunc->getContext(), Kind: kind, Val: result));
1406 else
1407 llvmFunc->addFnAttr(Kind: key, Val: value);
1408 return success();
1409 }
1410
1411 if (!value.empty())
1412 return emitError(loc) << "LLVM attribute '" << key
1413 << "' does not expect a value, found '" << value
1414 << "'";
1415
1416 llvmFunc->addFnAttr(Kind: kind);
1417 return success();
1418}
1419
1420/// Return a representation of `value` as metadata.
1421static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
1422 const llvm::APInt &value) {
1423 llvm::Constant *constant = llvm::ConstantInt::get(Context&: context, V: value);
1424 return llvm::ConstantAsMetadata::get(C: constant);
1425}
1426
1427/// Return a representation of `value` as an MDNode.
1428static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
1429 const llvm::APInt &value) {
1430 return llvm::MDNode::get(Context&: context, MDs: convertIntegerToMetadata(context, value));
1431}
1432
1433/// Return an MDNode encoding `vec_type_hint` metadata.
1434static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
1435 llvm::Type *type,
1436 bool isSigned) {
1437 llvm::Metadata *typeMD =
1438 llvm::ConstantAsMetadata::get(C: llvm::UndefValue::get(T: type));
1439 llvm::Metadata *isSignedMD =
1440 convertIntegerToMetadata(context, value: llvm::APInt(32, isSigned ? 1 : 0));
1441 return llvm::MDNode::get(Context&: context, MDs: {typeMD, isSignedMD});
1442}
1443
1444/// Return an MDNode with a tuple given by the values in `values`.
1445static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
1446 ArrayRef<int32_t> values) {
1447 SmallVector<llvm::Metadata *> mdValues;
1448 llvm::transform(
1449 Range&: values, d_first: std::back_inserter(x&: mdValues), F: [&context](int32_t value) {
1450 return convertIntegerToMetadata(context, value: llvm::APInt(32, value));
1451 });
1452 return llvm::MDNode::get(Context&: context, MDs: mdValues);
1453}
1454
1455/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
1456/// Reports error to `loc` if any and returns immediately. Expects `attributes`
1457/// to be an array attribute containing either string attributes, treated as
1458/// value-less LLVM attributes, or array attributes containing two string
1459/// attributes, with the first string being the name of the corresponding LLVM
1460/// attribute and the second string beings its value. Note that even integer
1461/// attributes are expected to have their values expressed as strings.
1462static LogicalResult
1463forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
1464 llvm::Function *llvmFunc) {
1465 if (!attributes)
1466 return success();
1467
1468 for (Attribute attr : *attributes) {
1469 if (auto stringAttr = dyn_cast<StringAttr>(Val&: attr)) {
1470 if (failed(
1471 Result: checkedAddLLVMFnAttribute(loc, llvmFunc, key: stringAttr.getValue())))
1472 return failure();
1473 continue;
1474 }
1475
1476 auto arrayAttr = dyn_cast<ArrayAttr>(Val&: attr);
1477 if (!arrayAttr || arrayAttr.size() != 2)
1478 return emitError(loc)
1479 << "expected 'passthrough' to contain string or array attributes";
1480
1481 auto keyAttr = dyn_cast<StringAttr>(Val: arrayAttr[0]);
1482 auto valueAttr = dyn_cast<StringAttr>(Val: arrayAttr[1]);
1483 if (!keyAttr || !valueAttr)
1484 return emitError(loc)
1485 << "expected arrays within 'passthrough' to contain two strings";
1486
1487 if (failed(Result: checkedAddLLVMFnAttribute(loc, llvmFunc, key: keyAttr.getValue(),
1488 value: valueAttr.getValue())))
1489 return failure();
1490 }
1491 return success();
1492}
1493
1494LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
1495 // Clear the block, branch value mappings, they are only relevant within one
1496 // function.
1497 blockMapping.clear();
1498 valueMapping.clear();
1499 branchMapping.clear();
1500 llvm::Function *llvmFunc = lookupFunction(name: func.getName());
1501
1502 // Add function arguments to the value remapping table.
1503 for (auto [mlirArg, llvmArg] :
1504 llvm::zip(t: func.getArguments(), u: llvmFunc->args()))
1505 mapValue(mlir: mlirArg, llvm: &llvmArg);
1506
1507 // Check the personality and set it.
1508 if (func.getPersonality()) {
1509 llvm::Type *ty = llvm::PointerType::getUnqual(C&: llvmFunc->getContext());
1510 if (llvm::Constant *pfunc = getLLVMConstant(llvmType: ty, attr: func.getPersonalityAttr(),
1511 loc: func.getLoc(), moduleTranslation: *this))
1512 llvmFunc->setPersonalityFn(pfunc);
1513 }
1514
1515 if (std::optional<StringRef> section = func.getSection())
1516 llvmFunc->setSection(*section);
1517
1518 if (func.getArmStreaming())
1519 llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_enabled");
1520 else if (func.getArmLocallyStreaming())
1521 llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_body");
1522 else if (func.getArmStreamingCompatible())
1523 llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_compatible");
1524
1525 if (func.getArmNewZa())
1526 llvmFunc->addFnAttr(Kind: "aarch64_new_za");
1527 else if (func.getArmInZa())
1528 llvmFunc->addFnAttr(Kind: "aarch64_in_za");
1529 else if (func.getArmOutZa())
1530 llvmFunc->addFnAttr(Kind: "aarch64_out_za");
1531 else if (func.getArmInoutZa())
1532 llvmFunc->addFnAttr(Kind: "aarch64_inout_za");
1533 else if (func.getArmPreservesZa())
1534 llvmFunc->addFnAttr(Kind: "aarch64_preserves_za");
1535
1536 if (auto targetCpu = func.getTargetCpu())
1537 llvmFunc->addFnAttr(Kind: "target-cpu", Val: *targetCpu);
1538
1539 if (auto tuneCpu = func.getTuneCpu())
1540 llvmFunc->addFnAttr(Kind: "tune-cpu", Val: *tuneCpu);
1541
1542 if (auto reciprocalEstimates = func.getReciprocalEstimates())
1543 llvmFunc->addFnAttr(Kind: "reciprocal-estimates", Val: *reciprocalEstimates);
1544
1545 if (auto preferVectorWidth = func.getPreferVectorWidth())
1546 llvmFunc->addFnAttr(Kind: "prefer-vector-width", Val: *preferVectorWidth);
1547
1548 if (auto attr = func.getVscaleRange())
1549 llvmFunc->addFnAttr(Attr: llvm::Attribute::getWithVScaleRangeArgs(
1550 Context&: getLLVMContext(), MinValue: attr->getMinRange().getInt(),
1551 MaxValue: attr->getMaxRange().getInt()));
1552
1553 if (auto unsafeFpMath = func.getUnsafeFpMath())
1554 llvmFunc->addFnAttr(Kind: "unsafe-fp-math", Val: llvm::toStringRef(B: *unsafeFpMath));
1555
1556 if (auto noInfsFpMath = func.getNoInfsFpMath())
1557 llvmFunc->addFnAttr(Kind: "no-infs-fp-math", Val: llvm::toStringRef(B: *noInfsFpMath));
1558
1559 if (auto noNansFpMath = func.getNoNansFpMath())
1560 llvmFunc->addFnAttr(Kind: "no-nans-fp-math", Val: llvm::toStringRef(B: *noNansFpMath));
1561
1562 if (auto approxFuncFpMath = func.getApproxFuncFpMath())
1563 llvmFunc->addFnAttr(Kind: "approx-func-fp-math",
1564 Val: llvm::toStringRef(B: *approxFuncFpMath));
1565
1566 if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
1567 llvmFunc->addFnAttr(Kind: "no-signed-zeros-fp-math",
1568 Val: llvm::toStringRef(B: *noSignedZerosFpMath));
1569
1570 if (auto denormalFpMath = func.getDenormalFpMath())
1571 llvmFunc->addFnAttr(Kind: "denormal-fp-math", Val: *denormalFpMath);
1572
1573 if (auto denormalFpMathF32 = func.getDenormalFpMathF32())
1574 llvmFunc->addFnAttr(Kind: "denormal-fp-math-f32", Val: *denormalFpMathF32);
1575
1576 if (auto fpContract = func.getFpContract())
1577 llvmFunc->addFnAttr(Kind: "fp-contract", Val: *fpContract);
1578
1579 if (auto instrumentFunctionEntry = func.getInstrumentFunctionEntry())
1580 llvmFunc->addFnAttr(Kind: "instrument-function-entry", Val: *instrumentFunctionEntry);
1581
1582 if (auto instrumentFunctionExit = func.getInstrumentFunctionExit())
1583 llvmFunc->addFnAttr(Kind: "instrument-function-exit", Val: *instrumentFunctionExit);
1584
1585 // First, create all blocks so we can jump to them.
1586 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1587 for (auto &bb : func) {
1588 auto *llvmBB = llvm::BasicBlock::Create(Context&: llvmContext);
1589 llvmBB->insertInto(Parent: llvmFunc);
1590 mapBlock(mlir: &bb, llvm: llvmBB);
1591 }
1592
1593 // Then, convert blocks one by one in topological order to ensure defs are
1594 // converted before uses.
1595 auto blocks = getBlocksSortedByDominance(region&: func.getBody());
1596 for (Block *bb : blocks) {
1597 CapturingIRBuilder builder(llvmContext,
1598 llvm::TargetFolder(llvmModule->getDataLayout()));
1599 if (failed(Result: convertBlockImpl(bb&: *bb, ignoreArguments: bb->isEntryBlock(), builder,
1600 /*recordInsertions=*/true)))
1601 return failure();
1602 }
1603
1604 // After all blocks have been traversed and values mapped, connect the PHI
1605 // nodes to the results of preceding blocks.
1606 detail::connectPHINodes(region&: func.getBody(), state: *this);
1607
1608 // Finally, convert dialect attributes attached to the function.
1609 return convertDialectAttributes(op: func, instructions: {});
1610}
1611
1612LogicalResult ModuleTranslation::convertDialectAttributes(
1613 Operation *op, ArrayRef<llvm::Instruction *> instructions) {
1614 for (NamedAttribute attribute : op->getDialectAttrs())
1615 if (failed(Result: iface.amendOperation(op, instructions, attribute, moduleTranslation&: *this)))
1616 return failure();
1617 return success();
1618}
1619
1620/// Converts memory effect attributes from `func` and attaches them to
1621/// `llvmFunc`.
1622static void convertFunctionMemoryAttributes(LLVMFuncOp func,
1623 llvm::Function *llvmFunc) {
1624 if (!func.getMemoryEffects())
1625 return;
1626
1627 MemoryEffectsAttr memEffects = func.getMemoryEffectsAttr();
1628
1629 // Add memory effects incrementally.
1630 llvm::MemoryEffects newMemEffects =
1631 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
1632 convertModRefInfoToLLVM(value: memEffects.getArgMem()));
1633 newMemEffects |= llvm::MemoryEffects(
1634 llvm::MemoryEffects::Location::InaccessibleMem,
1635 convertModRefInfoToLLVM(value: memEffects.getInaccessibleMem()));
1636 newMemEffects |=
1637 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
1638 convertModRefInfoToLLVM(value: memEffects.getOther()));
1639 llvmFunc->setMemoryEffects(newMemEffects);
1640}
1641
1642/// Converts function attributes from `func` and attaches them to `llvmFunc`.
1643static void convertFunctionAttributes(LLVMFuncOp func,
1644 llvm::Function *llvmFunc) {
1645 if (func.getNoInlineAttr())
1646 llvmFunc->addFnAttr(Kind: llvm::Attribute::NoInline);
1647 if (func.getAlwaysInlineAttr())
1648 llvmFunc->addFnAttr(Kind: llvm::Attribute::AlwaysInline);
1649 if (func.getOptimizeNoneAttr())
1650 llvmFunc->addFnAttr(Kind: llvm::Attribute::OptimizeNone);
1651 if (func.getConvergentAttr())
1652 llvmFunc->addFnAttr(Kind: llvm::Attribute::Convergent);
1653 if (func.getNoUnwindAttr())
1654 llvmFunc->addFnAttr(Kind: llvm::Attribute::NoUnwind);
1655 if (func.getWillReturnAttr())
1656 llvmFunc->addFnAttr(Kind: llvm::Attribute::WillReturn);
1657 if (TargetFeaturesAttr targetFeatAttr = func.getTargetFeaturesAttr())
1658 llvmFunc->addFnAttr(Kind: "target-features", Val: targetFeatAttr.getFeaturesString());
1659 if (FramePointerKindAttr fpAttr = func.getFramePointerAttr())
1660 llvmFunc->addFnAttr(Kind: "frame-pointer", Val: stringifyFramePointerKind(
1661 fpAttr.getFramePointerKind()));
1662 if (UWTableKindAttr uwTableKindAttr = func.getUwtableKindAttr())
1663 llvmFunc->setUWTableKind(
1664 convertUWTableKindToLLVM(value: uwTableKindAttr.getUwtableKind()));
1665 convertFunctionMemoryAttributes(func, llvmFunc);
1666}
1667
1668/// Converts function attributes from `func` and attaches them to `llvmFunc`.
1669static void convertFunctionKernelAttributes(LLVMFuncOp func,
1670 llvm::Function *llvmFunc,
1671 ModuleTranslation &translation) {
1672 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1673
1674 if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) {
1675 Type type = vecTypeHint.getHint().getValue();
1676 llvm::Type *llvmType = translation.convertType(type);
1677 bool isSigned = vecTypeHint.getIsSigned();
1678 llvmFunc->setMetadata(
1679 Kind: func.getVecTypeHintAttrName(),
1680 Node: convertVecTypeHintToMDNode(context&: llvmContext, type: llvmType, isSigned));
1681 }
1682
1683 if (std::optional<ArrayRef<int32_t>> workGroupSizeHint =
1684 func.getWorkGroupSizeHint()) {
1685 llvmFunc->setMetadata(
1686 Kind: func.getWorkGroupSizeHintAttrName(),
1687 Node: convertIntegerArrayToMDNode(context&: llvmContext, values: *workGroupSizeHint));
1688 }
1689
1690 if (std::optional<ArrayRef<int32_t>> reqdWorkGroupSize =
1691 func.getReqdWorkGroupSize()) {
1692 llvmFunc->setMetadata(
1693 Kind: func.getReqdWorkGroupSizeAttrName(),
1694 Node: convertIntegerArrayToMDNode(context&: llvmContext, values: *reqdWorkGroupSize));
1695 }
1696
1697 if (std::optional<uint32_t> intelReqdSubGroupSize =
1698 func.getIntelReqdSubGroupSize()) {
1699 llvmFunc->setMetadata(
1700 Kind: func.getIntelReqdSubGroupSizeAttrName(),
1701 Node: convertIntegerToMDNode(context&: llvmContext,
1702 value: llvm::APInt(32, *intelReqdSubGroupSize)));
1703 }
1704}
1705
1706static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder,
1707 llvm::Attribute::AttrKind llvmKind,
1708 NamedAttribute namedAttr,
1709 ModuleTranslation &moduleTranslation,
1710 Location loc) {
1711 return llvm::TypeSwitch<Attribute, LogicalResult>(namedAttr.getValue())
1712 .Case<TypeAttr>(caseFn: [&](auto typeAttr) {
1713 attrBuilder.addTypeAttr(
1714 Kind: llvmKind, Ty: moduleTranslation.convertType(type: typeAttr.getValue()));
1715 return success();
1716 })
1717 .Case<IntegerAttr>(caseFn: [&](auto intAttr) {
1718 attrBuilder.addRawIntAttr(Kind: llvmKind, Value: intAttr.getInt());
1719 return success();
1720 })
1721 .Case<UnitAttr>(caseFn: [&](auto) {
1722 attrBuilder.addAttribute(Val: llvmKind);
1723 return success();
1724 })
1725 .Case<LLVM::ConstantRangeAttr>(caseFn: [&](auto rangeAttr) {
1726 attrBuilder.addConstantRangeAttr(
1727 Kind: llvmKind,
1728 CR: llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
1729 return success();
1730 })
1731 .Default(defaultFn: [loc](auto) {
1732 return emitError(loc, message: "unsupported parameter attribute type");
1733 });
1734}
1735
1736FailureOr<llvm::AttrBuilder>
1737ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
1738 DictionaryAttr paramAttrs) {
1739 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1740 auto attrNameToKindMapping = getAttrNameToKindMapping();
1741 Location loc = func.getLoc();
1742
1743 for (auto namedAttr : paramAttrs) {
1744 auto it = attrNameToKindMapping.find(Val: namedAttr.getName());
1745 if (it != attrNameToKindMapping.end()) {
1746 llvm::Attribute::AttrKind llvmKind = it->second;
1747 if (failed(Result: convertParameterAttr(attrBuilder, llvmKind, namedAttr, moduleTranslation&: *this,
1748 loc)))
1749 return failure();
1750 } else if (namedAttr.getNameDialect()) {
1751 if (failed(Result: iface.convertParameterAttr(function: func, argIdx, attribute: namedAttr, moduleTranslation&: *this)))
1752 return failure();
1753 }
1754 }
1755
1756 return attrBuilder;
1757}
1758
1759FailureOr<llvm::AttrBuilder>
1760ModuleTranslation::convertParameterAttrs(Location loc,
1761 DictionaryAttr paramAttrs) {
1762 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1763 auto attrNameToKindMapping = getAttrNameToKindMapping();
1764
1765 for (auto namedAttr : paramAttrs) {
1766 auto it = attrNameToKindMapping.find(Val: namedAttr.getName());
1767 if (it != attrNameToKindMapping.end()) {
1768 llvm::Attribute::AttrKind llvmKind = it->second;
1769 if (failed(Result: convertParameterAttr(attrBuilder, llvmKind, namedAttr, moduleTranslation&: *this,
1770 loc)))
1771 return failure();
1772 }
1773 }
1774
1775 return attrBuilder;
1776}
1777
1778LogicalResult ModuleTranslation::convertFunctionSignatures() {
1779 // Declare all functions first because there may be function calls that form a
1780 // call graph with cycles, or global initializers that reference functions.
1781 for (auto function : getModuleBody(module: mlirModule).getOps<LLVMFuncOp>()) {
1782 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1783 Name: function.getName(),
1784 T: cast<llvm::FunctionType>(Val: convertType(type: function.getFunctionType())));
1785 llvm::Function *llvmFunc = cast<llvm::Function>(Val: llvmFuncCst.getCallee());
1786 llvmFunc->setLinkage(convertLinkageToLLVM(value: function.getLinkage()));
1787 llvmFunc->setCallingConv(convertCConvToLLVM(value: function.getCConv()));
1788 mapFunction(name: function.getName(), func: llvmFunc);
1789 addRuntimePreemptionSpecifier(dsoLocalRequested: function.getDsoLocal(), gv: llvmFunc);
1790
1791 // Convert function attributes.
1792 convertFunctionAttributes(func: function, llvmFunc);
1793
1794 // Convert function kernel attributes to metadata.
1795 convertFunctionKernelAttributes(func: function, llvmFunc, translation&: *this);
1796
1797 // Convert function_entry_count attribute to metadata.
1798 if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
1799 llvmFunc->setEntryCount(Count: entryCount.value());
1800
1801 // Convert result attributes.
1802 if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
1803 DictionaryAttr resultAttrs = cast<DictionaryAttr>(Val: allResultAttrs[0]);
1804 FailureOr<llvm::AttrBuilder> attrBuilder =
1805 convertParameterAttrs(func: function, argIdx: -1, paramAttrs: resultAttrs);
1806 if (failed(Result: attrBuilder))
1807 return failure();
1808 llvmFunc->addRetAttrs(Attrs: *attrBuilder);
1809 }
1810
1811 // Convert argument attributes.
1812 for (auto [argIdx, llvmArg] : llvm::enumerate(First: llvmFunc->args())) {
1813 if (DictionaryAttr argAttrs = function.getArgAttrDict(index: argIdx)) {
1814 FailureOr<llvm::AttrBuilder> attrBuilder =
1815 convertParameterAttrs(func: function, argIdx, paramAttrs: argAttrs);
1816 if (failed(Result: attrBuilder))
1817 return failure();
1818 llvmArg.addAttrs(B&: *attrBuilder);
1819 }
1820 }
1821
1822 // Forward the pass-through attributes to LLVM.
1823 if (failed(Result: forwardPassthroughAttributes(
1824 loc: function.getLoc(), attributes: function.getPassthrough(), llvmFunc)))
1825 return failure();
1826
1827 // Convert visibility attribute.
1828 llvmFunc->setVisibility(convertVisibilityToLLVM(value: function.getVisibility_()));
1829
1830 // Convert the comdat attribute.
1831 if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
1832 auto selectorOp = cast<ComdatSelectorOp>(
1833 Val: SymbolTable::lookupNearestSymbolFrom(from: function, symbol: *comdat));
1834 llvmFunc->setComdat(comdatMapping.lookup(Val: selectorOp));
1835 }
1836
1837 if (auto gc = function.getGarbageCollector())
1838 llvmFunc->setGC(gc->str());
1839
1840 if (auto unnamedAddr = function.getUnnamedAddr())
1841 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(value: *unnamedAddr));
1842
1843 if (auto alignment = function.getAlignment())
1844 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1845
1846 // Translate the debug information for this function.
1847 debugTranslation->translate(func: function, llvmFunc&: *llvmFunc);
1848 }
1849
1850 return success();
1851}
1852
1853LogicalResult ModuleTranslation::convertFunctions() {
1854 // Convert functions.
1855 for (auto function : getModuleBody(module: mlirModule).getOps<LLVMFuncOp>()) {
1856 // Do not convert external functions, but do process dialect attributes
1857 // attached to them.
1858 if (function.isExternal()) {
1859 if (failed(Result: convertDialectAttributes(op: function, instructions: {})))
1860 return failure();
1861 continue;
1862 }
1863
1864 if (failed(Result: convertOneFunction(func: function)))
1865 return failure();
1866 }
1867
1868 return success();
1869}
1870
1871LogicalResult ModuleTranslation::convertComdats() {
1872 for (auto comdatOp : getModuleBody(module: mlirModule).getOps<ComdatOp>()) {
1873 for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1874 llvm::Module *module = getLLVMModule();
1875 if (module->getComdatSymbolTable().contains(Key: selectorOp.getSymName()))
1876 return emitError(loc: selectorOp.getLoc())
1877 << "comdat selection symbols must be unique even in different "
1878 "comdat regions";
1879 llvm::Comdat *comdat = module->getOrInsertComdat(Name: selectorOp.getSymName());
1880 comdat->setSelectionKind(convertComdatToLLVM(value: selectorOp.getComdat()));
1881 comdatMapping.try_emplace(Key: selectorOp, Args&: comdat);
1882 }
1883 }
1884 return success();
1885}
1886
1887LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() {
1888 for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) {
1889 BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
1890 llvm::BasicBlock *llvmBlock = lookupBlockAddress(attr: blockAddressAttr);
1891 assert(llvmBlock && "expected LLVM blocks to be already translated");
1892
1893 // Update mapping with new block address constant.
1894 auto *llvmBlockAddr = llvm::BlockAddress::get(
1895 F: lookupFunction(name: blockAddressAttr.getFunction().getValue()), BB: llvmBlock);
1896 llvmCst->replaceAllUsesWith(V: llvmBlockAddr);
1897 assert(llvmCst->use_empty() && "expected all uses to be replaced");
1898 cast<llvm::GlobalVariable>(Val: llvmCst)->eraseFromParent();
1899 }
1900 unresolvedBlockAddressMapping.clear();
1901 return success();
1902}
1903
1904void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
1905 llvm::Instruction *inst) {
1906 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1907 inst->setMetadata(KindID: llvm::LLVMContext::MD_access_group, Node: node);
1908}
1909
1910llvm::MDNode *
1911ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
1912 auto [scopeIt, scopeInserted] =
1913 aliasScopeMetadataMapping.try_emplace(Key: aliasScopeAttr, Args: nullptr);
1914 if (!scopeInserted)
1915 return scopeIt->second;
1916 llvm::LLVMContext &ctx = llvmModule->getContext();
1917 auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: {});
1918 // Convert the domain metadata node if necessary.
1919 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1920 Key: aliasScopeAttr.getDomain(), Args: nullptr);
1921 if (insertedDomain) {
1922 llvm::SmallVector<llvm::Metadata *, 2> operands;
1923 // Placeholder for potential self-reference.
1924 operands.push_back(Elt: dummy.get());
1925 if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1926 operands.push_back(Elt: llvm::MDString::get(Context&: ctx, Str: description));
1927 domainIt->second = llvm::MDNode::get(Context&: ctx, MDs: operands);
1928 // Self-reference for uniqueness.
1929 llvm::Metadata *replacement;
1930 if (auto stringAttr =
1931 dyn_cast<StringAttr>(Val: aliasScopeAttr.getDomain().getId()))
1932 replacement = llvm::MDString::get(Context&: ctx, Str: stringAttr.getValue());
1933 else
1934 replacement = domainIt->second;
1935 domainIt->second->replaceOperandWith(I: 0, New: replacement);
1936 }
1937 // Convert the scope metadata node.
1938 assert(domainIt->second && "Scope's domain should already be valid");
1939 llvm::SmallVector<llvm::Metadata *, 3> operands;
1940 // Placeholder for potential self-reference.
1941 operands.push_back(Elt: dummy.get());
1942 operands.push_back(Elt: domainIt->second);
1943 if (StringAttr description = aliasScopeAttr.getDescription())
1944 operands.push_back(Elt: llvm::MDString::get(Context&: ctx, Str: description));
1945 scopeIt->second = llvm::MDNode::get(Context&: ctx, MDs: operands);
1946 // Self-reference for uniqueness.
1947 llvm::Metadata *replacement;
1948 if (auto stringAttr = dyn_cast<StringAttr>(Val: aliasScopeAttr.getId()))
1949 replacement = llvm::MDString::get(Context&: ctx, Str: stringAttr.getValue());
1950 else
1951 replacement = scopeIt->second;
1952 scopeIt->second->replaceOperandWith(I: 0, New: replacement);
1953 return scopeIt->second;
1954}
1955
1956llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes(
1957 ArrayRef<AliasScopeAttr> aliasScopeAttrs) {
1958 SmallVector<llvm::Metadata *> nodes;
1959 nodes.reserve(N: aliasScopeAttrs.size());
1960 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1961 nodes.push_back(Elt: getOrCreateAliasScope(aliasScopeAttr));
1962 return llvm::MDNode::get(Context&: getLLVMContext(), MDs: nodes);
1963}
1964
1965void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
1966 llvm::Instruction *inst) {
1967 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) {
1968 if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1969 return;
1970 llvm::MDNode *node = getOrCreateAliasScopes(
1971 aliasScopeAttrs: llvm::to_vector(Range: aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1972 inst->setMetadata(KindID: kind, Node: node);
1973 };
1974
1975 populateScopeMetadata(op.getAliasScopesOrNull(),
1976 llvm::LLVMContext::MD_alias_scope);
1977 populateScopeMetadata(op.getNoAliasScopesOrNull(),
1978 llvm::LLVMContext::MD_noalias);
1979}
1980
1981llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1982 return tbaaMetadataMapping.lookup(Val: tbaaAttr);
1983}
1984
1985void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
1986 llvm::Instruction *inst) {
1987 ArrayAttr tagRefs = op.getTBAATagsOrNull();
1988 if (!tagRefs || tagRefs.empty())
1989 return;
1990
1991 // LLVM IR currently does not support attaching more than one TBAA access tag
1992 // to a memory accessing instruction. It may be useful to support this in
1993 // future, but for the time being just ignore the metadata if MLIR operation
1994 // has multiple access tags.
1995 if (tagRefs.size() > 1) {
1996 op.emitWarning() << "TBAA access tags were not translated, because LLVM "
1997 "IR only supports a single tag per instruction";
1998 return;
1999 }
2000
2001 llvm::MDNode *node = getTBAANode(tbaaAttr: cast<TBAATagAttr>(Val: tagRefs[0]));
2002 inst->setMetadata(KindID: llvm::LLVMContext::MD_tbaa, Node: node);
2003}
2004
2005void ModuleTranslation::setDereferenceableMetadata(
2006 DereferenceableOpInterface op, llvm::Instruction *inst) {
2007 DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
2008 if (!derefAttr)
2009 return;
2010
2011 llvm::MDNode *derefSizeNode = llvm::MDNode::get(
2012 Context&: getLLVMContext(),
2013 MDs: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get(
2014 Ty: llvm::IntegerType::get(C&: getLLVMContext(), NumBits: 64), V: derefAttr.getBytes())));
2015 unsigned kindId = derefAttr.getMayBeNull()
2016 ? llvm::LLVMContext::MD_dereferenceable_or_null
2017 : llvm::LLVMContext::MD_dereferenceable;
2018 inst->setMetadata(KindID: kindId, Node: derefSizeNode);
2019}
2020
2021void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
2022 SmallVector<uint32_t> weights;
2023 llvm::transform(Range: op.getWeights(), d_first: std::back_inserter(x&: weights),
2024 F: [](int32_t value) { return static_cast<uint32_t>(value); });
2025 if (weights.empty())
2026 return;
2027
2028 llvm::Instruction *inst = isa<CallOp>(Val: op) ? lookupCall(op) : lookupBranch(op);
2029 assert(inst && "expected the operation to have a mapping to an instruction");
2030 inst->setMetadata(
2031 KindID: llvm::LLVMContext::MD_prof,
2032 Node: llvm::MDBuilder(getLLVMContext()).createBranchWeights(Weights: weights));
2033}
2034
2035LogicalResult ModuleTranslation::createTBAAMetadata() {
2036 llvm::LLVMContext &ctx = llvmModule->getContext();
2037 llvm::IntegerType *offsetTy = llvm::IntegerType::get(C&: ctx, NumBits: 64);
2038
2039 // Walk the entire module and create all metadata nodes for the TBAA
2040 // attributes. The code below relies on two invariants of the
2041 // `AttrTypeWalker`:
2042 // 1. Attributes are visited in post-order: Since the attributes create a DAG,
2043 // this ensures that any lookups into `tbaaMetadataMapping` for child
2044 // attributes succeed.
2045 // 2. Attributes are only ever visited once: This way we don't leak any
2046 // LLVM metadata instances.
2047 AttrTypeWalker walker;
2048 walker.addWalk(callback: [&](TBAARootAttr root) {
2049 tbaaMetadataMapping.insert(
2050 KV: {root, llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: root.getId()))});
2051 });
2052
2053 walker.addWalk(callback: [&](TBAATypeDescriptorAttr descriptor) {
2054 SmallVector<llvm::Metadata *> operands;
2055 operands.push_back(Elt: llvm::MDString::get(Context&: ctx, Str: descriptor.getId()));
2056 for (TBAAMemberAttr member : descriptor.getMembers()) {
2057 operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: member.getTypeDesc()));
2058 operands.push_back(Elt: llvm::ConstantAsMetadata::get(
2059 C: llvm::ConstantInt::get(Ty: offsetTy, V: member.getOffset())));
2060 }
2061
2062 tbaaMetadataMapping.insert(KV: {descriptor, llvm::MDNode::get(Context&: ctx, MDs: operands)});
2063 });
2064
2065 walker.addWalk(callback: [&](TBAATagAttr tag) {
2066 SmallVector<llvm::Metadata *> operands;
2067
2068 operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getBaseType()));
2069 operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getAccessType()));
2070
2071 operands.push_back(Elt: llvm::ConstantAsMetadata::get(
2072 C: llvm::ConstantInt::get(Ty: offsetTy, V: tag.getOffset())));
2073 if (tag.getConstant())
2074 operands.push_back(
2075 Elt: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get(Ty: offsetTy, V: 1)));
2076
2077 tbaaMetadataMapping.insert(KV: {tag, llvm::MDNode::get(Context&: ctx, MDs: operands)});
2078 });
2079
2080 mlirModule->walk(callback: [&](AliasAnalysisOpInterface analysisOpInterface) {
2081 if (auto attr = analysisOpInterface.getTBAATagsOrNull())
2082 walker.walk(element: attr);
2083 });
2084
2085 return success();
2086}
2087
2088LogicalResult ModuleTranslation::createIdentMetadata() {
2089 if (auto attr = mlirModule->getAttrOfType<StringAttr>(
2090 name: LLVMDialect::getIdentAttrName())) {
2091 StringRef ident = attr;
2092 llvm::LLVMContext &ctx = llvmModule->getContext();
2093 llvm::NamedMDNode *namedMd =
2094 llvmModule->getOrInsertNamedMetadata(Name: LLVMDialect::getIdentAttrName());
2095 llvm::MDNode *md = llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: ident));
2096 namedMd->addOperand(M: md);
2097 }
2098
2099 return success();
2100}
2101
2102LogicalResult ModuleTranslation::createCommandlineMetadata() {
2103 if (auto attr = mlirModule->getAttrOfType<StringAttr>(
2104 name: LLVMDialect::getCommandlineAttrName())) {
2105 StringRef cmdLine = attr;
2106 llvm::LLVMContext &ctx = llvmModule->getContext();
2107 llvm::NamedMDNode *nmd = llvmModule->getOrInsertNamedMetadata(
2108 Name: LLVMDialect::getCommandlineAttrName());
2109 llvm::MDNode *md =
2110 llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: cmdLine));
2111 nmd->addOperand(M: md);
2112 }
2113
2114 return success();
2115}
2116
2117LogicalResult ModuleTranslation::createDependentLibrariesMetadata() {
2118 if (auto dependentLibrariesAttr = mlirModule->getDiscardableAttr(
2119 name: LLVM::LLVMDialect::getDependentLibrariesAttrName())) {
2120 auto *nmd =
2121 llvmModule->getOrInsertNamedMetadata(Name: "llvm.dependent-libraries");
2122 llvm::LLVMContext &ctx = llvmModule->getContext();
2123 for (auto libAttr :
2124 cast<ArrayAttr>(Val&: dependentLibrariesAttr).getAsRange<StringAttr>()) {
2125 auto *md =
2126 llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: libAttr.getValue()));
2127 nmd->addOperand(M: md);
2128 }
2129 }
2130 return success();
2131}
2132
2133void ModuleTranslation::setLoopMetadata(Operation *op,
2134 llvm::Instruction *inst) {
2135 LoopAnnotationAttr attr =
2136 TypeSwitch<Operation *, LoopAnnotationAttr>(op)
2137 .Case<LLVM::BrOp, LLVM::CondBrOp>(
2138 caseFn: [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
2139 if (!attr)
2140 return;
2141 llvm::MDNode *loopMD =
2142 loopAnnotationTranslation->translateLoopAnnotation(attr, op);
2143 inst->setMetadata(KindID: llvm::LLVMContext::MD_loop, Node: loopMD);
2144}
2145
2146void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
2147 auto iface = cast<DisjointFlagInterface>(Val: op);
2148 // We do a dyn_cast here in case the value got folded into a constant.
2149 if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(Val: value))
2150 disjointInst->setIsDisjoint(iface.getIsDisjoint());
2151}
2152
2153llvm::Type *ModuleTranslation::convertType(Type type) {
2154 return typeTranslator.translateType(type);
2155}
2156
2157/// A helper to look up remapped operands in the value remapping table.
2158SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) {
2159 SmallVector<llvm::Value *> remapped;
2160 remapped.reserve(N: values.size());
2161 for (Value v : values)
2162 remapped.push_back(Elt: lookupValue(value: v));
2163 return remapped;
2164}
2165
2166llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
2167 if (!ompBuilder) {
2168 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(args&: *llvmModule);
2169 ompBuilder->initialize();
2170
2171 // Flags represented as top-level OpenMP dialect attributes are set in
2172 // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
2173 // the default configuration.
2174 ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
2175 /* IsTargetDevice = */ false, /* IsGPU = */ false,
2176 /* OpenMPOffloadMandatory = */ false,
2177 /* HasRequiresReverseOffload = */ false,
2178 /* HasRequiresUnifiedAddress = */ false,
2179 /* HasRequiresUnifiedSharedMemory = */ false,
2180 /* HasRequiresDynamicAllocators = */ false));
2181 }
2182 return ompBuilder.get();
2183}
2184
2185llvm::DILocation *ModuleTranslation::translateLoc(Location loc,
2186 llvm::DILocalScope *scope) {
2187 return debugTranslation->translateLoc(loc, scope);
2188}
2189
2190llvm::DIExpression *
2191ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) {
2192 return debugTranslation->translateExpression(attr);
2193}
2194
2195llvm::DIGlobalVariableExpression *
2196ModuleTranslation::translateGlobalVariableExpression(
2197 LLVM::DIGlobalVariableExpressionAttr attr) {
2198 return debugTranslation->translateGlobalVariableExpression(attr);
2199}
2200
2201llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
2202 return debugTranslation->translate(attr);
2203}
2204
2205llvm::RoundingMode
2206ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
2207 return convertRoundingModeToLLVM(value: rounding);
2208}
2209
2210llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
2211 LLVM::FPExceptionBehavior exceptionBehavior) {
2212 return convertFPExceptionBehaviorToLLVM(value: exceptionBehavior);
2213}
2214
2215llvm::NamedMDNode *
2216ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
2217 return llvmModule->getOrInsertNamedMetadata(Name: name);
2218}
2219
2220static std::unique_ptr<llvm::Module>
2221prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
2222 StringRef name) {
2223 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
2224 auto llvmModule = std::make_unique<llvm::Module>(args&: name, args&: llvmContext);
2225 if (auto dataLayoutAttr =
2226 m->getDiscardableAttr(name: LLVM::LLVMDialect::getDataLayoutAttrName())) {
2227 llvmModule->setDataLayout(cast<StringAttr>(Val&: dataLayoutAttr).getValue());
2228 } else {
2229 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
2230 if (auto iface = dyn_cast<DataLayoutOpInterface>(Val: m)) {
2231 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
2232 llvmDataLayout =
2233 translateDataLayout(attribute: spec, dataLayout: DataLayout(iface), loc: m->getLoc());
2234 }
2235 } else if (auto mod = dyn_cast<ModuleOp>(Val: m)) {
2236 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
2237 llvmDataLayout =
2238 translateDataLayout(attribute: spec, dataLayout: DataLayout(mod), loc: m->getLoc());
2239 }
2240 }
2241 if (failed(Result: llvmDataLayout))
2242 return nullptr;
2243 llvmModule->setDataLayout(*llvmDataLayout);
2244 }
2245 if (auto targetTripleAttr =
2246 m->getDiscardableAttr(name: LLVM::LLVMDialect::getTargetTripleAttrName()))
2247 llvmModule->setTargetTriple(
2248 llvm::Triple(cast<StringAttr>(Val&: targetTripleAttr).getValue()));
2249
2250 return llvmModule;
2251}
2252
2253std::unique_ptr<llvm::Module>
2254mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
2255 StringRef name, bool disableVerification) {
2256 if (!satisfiesLLVMModule(op: module)) {
2257 module->emitOpError(message: "can not be translated to an LLVMIR module");
2258 return nullptr;
2259 }
2260
2261 std::unique_ptr<llvm::Module> llvmModule =
2262 prepareLLVMModule(m: module, llvmContext, name);
2263 if (!llvmModule)
2264 return nullptr;
2265
2266 LLVM::ensureDistinctSuccessors(op: module);
2267 LLVM::legalizeDIExpressionsRecursively(op: module);
2268
2269 ModuleTranslation translator(module, std::move(llvmModule));
2270 llvm::IRBuilder<llvm::TargetFolder> llvmBuilder(
2271 llvmContext,
2272 llvm::TargetFolder(translator.getLLVMModule()->getDataLayout()));
2273
2274 // Convert module before functions and operations inside, so dialect
2275 // attributes can be used to change dialect-specific global configurations via
2276 // `amendOperation()`. These configurations can then influence the translation
2277 // of operations afterwards.
2278 if (failed(Result: translator.convertOperation(op&: *module, builder&: llvmBuilder)))
2279 return nullptr;
2280
2281 if (failed(Result: translator.convertComdats()))
2282 return nullptr;
2283 if (failed(Result: translator.convertFunctionSignatures()))
2284 return nullptr;
2285 if (failed(Result: translator.convertGlobalsAndAliases()))
2286 return nullptr;
2287 if (failed(Result: translator.createTBAAMetadata()))
2288 return nullptr;
2289 if (failed(Result: translator.createIdentMetadata()))
2290 return nullptr;
2291 if (failed(Result: translator.createCommandlineMetadata()))
2292 return nullptr;
2293 if (failed(Result: translator.createDependentLibrariesMetadata()))
2294 return nullptr;
2295
2296 // Convert other top-level operations if possible.
2297 for (Operation &o : getModuleBody(module).getOperations()) {
2298 if (!isa<LLVM::LLVMFuncOp, LLVM::AliasOp, LLVM::GlobalOp,
2299 LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(Val: &o) &&
2300 !o.hasTrait<OpTrait::IsTerminator>() &&
2301 failed(Result: translator.convertOperation(op&: o, builder&: llvmBuilder))) {
2302 return nullptr;
2303 }
2304 }
2305
2306 // Operations in function bodies with symbolic references must be converted
2307 // after the top-level operations they refer to are declared, so we do it
2308 // last.
2309 if (failed(Result: translator.convertFunctions()))
2310 return nullptr;
2311
2312 // Now that all MLIR blocks are resolved into LLVM ones, patch block address
2313 // constants to point to the correct blocks.
2314 if (failed(Result: translator.convertUnresolvedBlockAddress()))
2315 return nullptr;
2316
2317 // Once we've finished constructing elements in the module, we should convert
2318 // it to use the debug info format desired by LLVM.
2319 // See https://llvm.org/docs/RemoveDIsDebugInfo.html
2320 translator.llvmModule->convertToNewDbgValues();
2321
2322 // Add the necessary debug info module flags, if they were not encoded in MLIR
2323 // beforehand.
2324 translator.debugTranslation->addModuleFlagsIfNotPresent();
2325
2326 // Call the OpenMP IR Builder callbacks prior to verifying the module
2327 if (auto *ompBuilder = translator.getOpenMPBuilder())
2328 ompBuilder->finalize();
2329
2330 if (!disableVerification &&
2331 llvm::verifyModule(M: *translator.llvmModule, OS: &llvm::errs()))
2332 return nullptr;
2333
2334 return std::move(translator.llvmModule);
2335}
2336

source code of mlir/lib/Target/LLVMIR/ModuleTranslation.cpp