| 1 | //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the translation between an MLIR LLVM dialect module and |
| 10 | // the corresponding LLVMIR module. It only handles core LLVM IR operations. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
| 15 | |
| 16 | #include "AttrKindDetail.h" |
| 17 | #include "DebugTranslation.h" |
| 18 | #include "LoopAnnotationTranslation.h" |
| 19 | #include "mlir/Analysis/TopologicalSortUtils.h" |
| 20 | #include "mlir/Dialect/DLTI/DLTI.h" |
| 21 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 22 | #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" |
| 23 | #include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h" |
| 24 | #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" |
| 25 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| 26 | #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" |
| 27 | #include "mlir/IR/AttrTypeSubElements.h" |
| 28 | #include "mlir/IR/Attributes.h" |
| 29 | #include "mlir/IR/BuiltinOps.h" |
| 30 | #include "mlir/IR/BuiltinTypes.h" |
| 31 | #include "mlir/IR/DialectResourceBlobManager.h" |
| 32 | #include "mlir/IR/RegionGraphTraits.h" |
| 33 | #include "mlir/Support/LLVM.h" |
| 34 | #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" |
| 35 | #include "mlir/Target/LLVMIR/TypeToLLVM.h" |
| 36 | |
| 37 | #include "llvm/ADT/PostOrderIterator.h" |
| 38 | #include "llvm/ADT/STLExtras.h" |
| 39 | #include "llvm/ADT/SetVector.h" |
| 40 | #include "llvm/ADT/StringExtras.h" |
| 41 | #include "llvm/ADT/TypeSwitch.h" |
| 42 | #include "llvm/Analysis/TargetFolder.h" |
| 43 | #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" |
| 44 | #include "llvm/IR/BasicBlock.h" |
| 45 | #include "llvm/IR/CFG.h" |
| 46 | #include "llvm/IR/Constants.h" |
| 47 | #include "llvm/IR/DerivedTypes.h" |
| 48 | #include "llvm/IR/IRBuilder.h" |
| 49 | #include "llvm/IR/InlineAsm.h" |
| 50 | #include "llvm/IR/IntrinsicsNVPTX.h" |
| 51 | #include "llvm/IR/LLVMContext.h" |
| 52 | #include "llvm/IR/MDBuilder.h" |
| 53 | #include "llvm/IR/Module.h" |
| 54 | #include "llvm/IR/Verifier.h" |
| 55 | #include "llvm/Support/Debug.h" |
| 56 | #include "llvm/Support/ErrorHandling.h" |
| 57 | #include "llvm/Support/raw_ostream.h" |
| 58 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 59 | #include "llvm/Transforms/Utils/Cloning.h" |
| 60 | #include "llvm/Transforms/Utils/ModuleUtils.h" |
| 61 | #include <numeric> |
| 62 | #include <optional> |
| 63 | |
| 64 | #define DEBUG_TYPE "llvm-dialect-to-llvm-ir" |
| 65 | |
| 66 | using namespace mlir; |
| 67 | using namespace mlir::LLVM; |
| 68 | using namespace mlir::LLVM::detail; |
| 69 | |
| 70 | #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" |
| 71 | |
| 72 | namespace { |
| 73 | /// A customized inserter for LLVM's IRBuilder that captures all LLVM IR |
| 74 | /// instructions that are created for future reference. |
| 75 | /// |
| 76 | /// This is intended to be used with the `CollectionScope` RAII object: |
| 77 | /// |
| 78 | /// llvm::IRBuilder<..., InstructionCapturingInserter> builder; |
| 79 | /// { |
| 80 | /// InstructionCapturingInserter::CollectionScope scope(builder); |
| 81 | /// // Call IRBuilder methods as usual. |
| 82 | /// |
| 83 | /// // This will return a list of all instructions created by the builder, |
| 84 | /// // in order of creation. |
| 85 | /// builder.getInserter().getCapturedInstructions(); |
| 86 | /// } |
| 87 | /// // This will return an empty list. |
| 88 | /// builder.getInserter().getCapturedInstructions(); |
| 89 | /// |
| 90 | /// The capturing functionality is _disabled_ by default for performance |
| 91 | /// consideration. It needs to be explicitly enabled, which is achieved by |
| 92 | /// creating a `CollectionScope`. |
| 93 | class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter { |
| 94 | public: |
| 95 | /// Constructs the inserter. |
| 96 | InstructionCapturingInserter() |
| 97 | : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) { |
| 98 | if (LLVM_LIKELY(enabled)) |
| 99 | capturedInstructions.push_back(instruction); |
| 100 | }) {} |
| 101 | |
| 102 | /// Returns the list of LLVM IR instructions captured since the last cleanup. |
| 103 | ArrayRef<llvm::Instruction *> getCapturedInstructions() const { |
| 104 | return capturedInstructions; |
| 105 | } |
| 106 | |
| 107 | /// Clears the list of captured LLVM IR instructions. |
| 108 | void clearCapturedInstructions() { capturedInstructions.clear(); } |
| 109 | |
| 110 | /// RAII object enabling the capture of created LLVM IR instructions. |
| 111 | class CollectionScope { |
| 112 | public: |
| 113 | /// Creates the scope for the given inserter. |
| 114 | CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing); |
| 115 | |
| 116 | /// Ends the scope. |
| 117 | ~CollectionScope(); |
| 118 | |
| 119 | ArrayRef<llvm::Instruction *> getCapturedInstructions() { |
| 120 | if (!inserter) |
| 121 | return {}; |
| 122 | return inserter->getCapturedInstructions(); |
| 123 | } |
| 124 | |
| 125 | private: |
| 126 | /// Back reference to the inserter. |
| 127 | InstructionCapturingInserter *inserter = nullptr; |
| 128 | |
| 129 | /// List of instructions in the inserter prior to this scope. |
| 130 | SmallVector<llvm::Instruction *> previouslyCollectedInstructions; |
| 131 | |
| 132 | /// Whether the inserter was enabled prior to this scope. |
| 133 | bool wasEnabled; |
| 134 | }; |
| 135 | |
| 136 | /// Enable or disable the capturing mechanism. |
| 137 | void setEnabled(bool enabled = true) { this->enabled = enabled; } |
| 138 | |
| 139 | private: |
| 140 | /// List of captured instructions. |
| 141 | SmallVector<llvm::Instruction *> capturedInstructions; |
| 142 | |
| 143 | /// Whether the collection is enabled. |
| 144 | bool enabled = false; |
| 145 | }; |
| 146 | |
| 147 | using CapturingIRBuilder = |
| 148 | llvm::IRBuilder<llvm::TargetFolder, InstructionCapturingInserter>; |
| 149 | } // namespace |
| 150 | |
| 151 | InstructionCapturingInserter::CollectionScope::CollectionScope( |
| 152 | llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) { |
| 153 | |
| 154 | if (!isBuilderCapturing) |
| 155 | return; |
| 156 | |
| 157 | auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder); |
| 158 | inserter = &capturingIRBuilder.getInserter(); |
| 159 | wasEnabled = inserter->enabled; |
| 160 | if (wasEnabled) |
| 161 | previouslyCollectedInstructions.swap(inserter->capturedInstructions); |
| 162 | inserter->setEnabled(true); |
| 163 | } |
| 164 | |
| 165 | InstructionCapturingInserter::CollectionScope::~CollectionScope() { |
| 166 | if (!inserter) |
| 167 | return; |
| 168 | |
| 169 | previouslyCollectedInstructions.swap(inserter->capturedInstructions); |
| 170 | // If collection was enabled (likely in another, surrounding scope), keep |
| 171 | // the instructions collected in this scope. |
| 172 | if (wasEnabled) { |
| 173 | llvm::append_range(inserter->capturedInstructions, |
| 174 | previouslyCollectedInstructions); |
| 175 | } |
| 176 | inserter->setEnabled(wasEnabled); |
| 177 | } |
| 178 | |
| 179 | /// Translates the given data layout spec attribute to the LLVM IR data layout. |
| 180 | /// Only integer, float, pointer and endianness entries are currently supported. |
| 181 | static FailureOr<llvm::DataLayout> |
| 182 | translateDataLayout(DataLayoutSpecInterface attribute, |
| 183 | const DataLayout &dataLayout, |
| 184 | std::optional<Location> loc = std::nullopt) { |
| 185 | if (!loc) |
| 186 | loc = UnknownLoc::get(attribute.getContext()); |
| 187 | |
| 188 | // Translate the endianness attribute. |
| 189 | std::string llvmDataLayout; |
| 190 | llvm::raw_string_ostream layoutStream(llvmDataLayout); |
| 191 | for (DataLayoutEntryInterface entry : attribute.getEntries()) { |
| 192 | auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey()); |
| 193 | if (!key) |
| 194 | continue; |
| 195 | if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) { |
| 196 | auto value = cast<StringAttr>(entry.getValue()); |
| 197 | bool isLittleEndian = |
| 198 | value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle; |
| 199 | layoutStream << "-" << (isLittleEndian ? "e" : "E" ); |
| 200 | continue; |
| 201 | } |
| 202 | if (key.getValue() == DLTIDialect::kDataLayoutManglingModeKey) { |
| 203 | auto value = cast<StringAttr>(entry.getValue()); |
| 204 | layoutStream << "-m:" << value.getValue(); |
| 205 | continue; |
| 206 | } |
| 207 | if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) { |
| 208 | auto value = cast<IntegerAttr>(entry.getValue()); |
| 209 | uint64_t space = value.getValue().getZExtValue(); |
| 210 | // Skip the default address space. |
| 211 | if (space == 0) |
| 212 | continue; |
| 213 | layoutStream << "-P" << space; |
| 214 | continue; |
| 215 | } |
| 216 | if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) { |
| 217 | auto value = cast<IntegerAttr>(entry.getValue()); |
| 218 | uint64_t space = value.getValue().getZExtValue(); |
| 219 | // Skip the default address space. |
| 220 | if (space == 0) |
| 221 | continue; |
| 222 | layoutStream << "-G" << space; |
| 223 | continue; |
| 224 | } |
| 225 | if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) { |
| 226 | auto value = cast<IntegerAttr>(entry.getValue()); |
| 227 | uint64_t space = value.getValue().getZExtValue(); |
| 228 | // Skip the default address space. |
| 229 | if (space == 0) |
| 230 | continue; |
| 231 | layoutStream << "-A" << space; |
| 232 | continue; |
| 233 | } |
| 234 | if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) { |
| 235 | auto value = cast<IntegerAttr>(entry.getValue()); |
| 236 | uint64_t alignment = value.getValue().getZExtValue(); |
| 237 | // Skip the default stack alignment. |
| 238 | if (alignment == 0) |
| 239 | continue; |
| 240 | layoutStream << "-S" << alignment; |
| 241 | continue; |
| 242 | } |
| 243 | if (key.getValue() == DLTIDialect::kDataLayoutFunctionPointerAlignmentKey) { |
| 244 | auto value = cast<FunctionPointerAlignmentAttr>(entry.getValue()); |
| 245 | uint64_t alignment = value.getAlignment(); |
| 246 | // Skip the default function pointer alignment. |
| 247 | if (alignment == 0) |
| 248 | continue; |
| 249 | layoutStream << "-F" << (value.getFunctionDependent() ? "n" : "i" ) |
| 250 | << alignment; |
| 251 | continue; |
| 252 | } |
| 253 | if (key.getValue() == DLTIDialect::kDataLayoutLegalIntWidthsKey) { |
| 254 | layoutStream << "-n" ; |
| 255 | llvm::interleave( |
| 256 | cast<DenseI32ArrayAttr>(entry.getValue()).asArrayRef(), layoutStream, |
| 257 | [&](int32_t val) { layoutStream << val; }, ":" ); |
| 258 | continue; |
| 259 | } |
| 260 | emitError(*loc) << "unsupported data layout key " << key; |
| 261 | return failure(); |
| 262 | } |
| 263 | |
| 264 | // Go through the list of entries to check which types are explicitly |
| 265 | // specified in entries. Where possible, data layout queries are used instead |
| 266 | // of directly inspecting the entries. |
| 267 | for (DataLayoutEntryInterface entry : attribute.getEntries()) { |
| 268 | auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()); |
| 269 | if (!type) |
| 270 | continue; |
| 271 | // Data layout for the index type is irrelevant at this point. |
| 272 | if (isa<IndexType>(type)) |
| 273 | continue; |
| 274 | layoutStream << "-" ; |
| 275 | LogicalResult result = |
| 276 | llvm::TypeSwitch<Type, LogicalResult>(type) |
| 277 | .Case<IntegerType, Float16Type, Float32Type, Float64Type, |
| 278 | Float80Type, Float128Type>([&](Type type) -> LogicalResult { |
| 279 | if (auto intType = dyn_cast<IntegerType>(type)) { |
| 280 | if (intType.getSignedness() != IntegerType::Signless) |
| 281 | return emitError(*loc) |
| 282 | << "unsupported data layout for non-signless integer " |
| 283 | << intType; |
| 284 | layoutStream << "i" ; |
| 285 | } else { |
| 286 | layoutStream << "f" ; |
| 287 | } |
| 288 | uint64_t size = dataLayout.getTypeSizeInBits(type); |
| 289 | uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u; |
| 290 | uint64_t preferred = |
| 291 | dataLayout.getTypePreferredAlignment(type) * 8u; |
| 292 | layoutStream << size << ":" << abi; |
| 293 | if (abi != preferred) |
| 294 | layoutStream << ":" << preferred; |
| 295 | return success(); |
| 296 | }) |
| 297 | .Case([&](LLVMPointerType type) { |
| 298 | layoutStream << "p" << type.getAddressSpace() << ":" ; |
| 299 | uint64_t size = dataLayout.getTypeSizeInBits(type); |
| 300 | uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u; |
| 301 | uint64_t preferred = |
| 302 | dataLayout.getTypePreferredAlignment(type) * 8u; |
| 303 | uint64_t index = *dataLayout.getTypeIndexBitwidth(type); |
| 304 | layoutStream << size << ":" << abi << ":" << preferred << ":" |
| 305 | << index; |
| 306 | return success(); |
| 307 | }) |
| 308 | .Default([loc](Type type) { |
| 309 | return emitError(*loc) |
| 310 | << "unsupported type in data layout: " << type; |
| 311 | }); |
| 312 | if (failed(result)) |
| 313 | return failure(); |
| 314 | } |
| 315 | StringRef layoutSpec(llvmDataLayout); |
| 316 | layoutSpec.consume_front(Prefix: "-" ); |
| 317 | |
| 318 | return llvm::DataLayout(layoutSpec); |
| 319 | } |
| 320 | |
| 321 | /// Builds a constant of a sequential LLVM type `type`, potentially containing |
| 322 | /// other sequential types recursively, from the individual constant values |
| 323 | /// provided in `constants`. `shape` contains the number of elements in nested |
| 324 | /// sequential types. Reports errors at `loc` and returns nullptr on error. |
| 325 | static llvm::Constant * |
| 326 | buildSequentialConstant(ArrayRef<llvm::Constant *> &constants, |
| 327 | ArrayRef<int64_t> shape, llvm::Type *type, |
| 328 | Location loc) { |
| 329 | if (shape.empty()) { |
| 330 | llvm::Constant *result = constants.front(); |
| 331 | constants = constants.drop_front(); |
| 332 | return result; |
| 333 | } |
| 334 | |
| 335 | llvm::Type *elementType; |
| 336 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) { |
| 337 | elementType = arrayTy->getElementType(); |
| 338 | } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) { |
| 339 | elementType = vectorTy->getElementType(); |
| 340 | } else { |
| 341 | emitError(loc) << "expected sequential LLVM types wrapping a scalar" ; |
| 342 | return nullptr; |
| 343 | } |
| 344 | |
| 345 | SmallVector<llvm::Constant *, 8> nested; |
| 346 | nested.reserve(N: shape.front()); |
| 347 | for (int64_t i = 0; i < shape.front(); ++i) { |
| 348 | nested.push_back(Elt: buildSequentialConstant(constants, shape: shape.drop_front(), |
| 349 | type: elementType, loc)); |
| 350 | if (!nested.back()) |
| 351 | return nullptr; |
| 352 | } |
| 353 | |
| 354 | if (shape.size() == 1 && type->isVectorTy()) |
| 355 | return llvm::ConstantVector::get(V: nested); |
| 356 | return llvm::ConstantArray::get( |
| 357 | T: llvm::ArrayType::get(ElementType: elementType, NumElements: shape.front()), V: nested); |
| 358 | } |
| 359 | |
| 360 | /// Returns the first non-sequential type nested in sequential types. |
| 361 | static llvm::Type *getInnermostElementType(llvm::Type *type) { |
| 362 | do { |
| 363 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: type)) { |
| 364 | type = arrayTy->getElementType(); |
| 365 | } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(Val: type)) { |
| 366 | type = vectorTy->getElementType(); |
| 367 | } else { |
| 368 | return type; |
| 369 | } |
| 370 | } while (true); |
| 371 | } |
| 372 | |
| 373 | /// Convert a dense elements attribute to an LLVM IR constant using its raw data |
| 374 | /// storage if possible. This supports elements attributes of tensor or vector |
| 375 | /// type and avoids constructing separate objects for individual values of the |
| 376 | /// innermost dimension. Constants for other dimensions are still constructed |
| 377 | /// recursively. Returns null if constructing from raw data is not supported for |
| 378 | /// this type, e.g., element type is not a power-of-two-sized primitive. Reports |
| 379 | /// other errors at `loc`. |
| 380 | static llvm::Constant * |
| 381 | convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, |
| 382 | llvm::Type *llvmType, |
| 383 | const ModuleTranslation &moduleTranslation) { |
| 384 | if (!denseElementsAttr) |
| 385 | return nullptr; |
| 386 | |
| 387 | llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType); |
| 388 | if (!llvm::ConstantDataSequential::isElementTypeCompatible(Ty: innermostLLVMType)) |
| 389 | return nullptr; |
| 390 | |
| 391 | ShapedType type = denseElementsAttr.getType(); |
| 392 | if (type.getNumElements() == 0) |
| 393 | return nullptr; |
| 394 | |
| 395 | // Check that the raw data size matches what is expected for the scalar size. |
| 396 | // TODO: in theory, we could repack the data here to keep constructing from |
| 397 | // raw data. |
| 398 | // TODO: we may also need to consider endianness when cross-compiling to an |
| 399 | // architecture where it is different. |
| 400 | int64_t elementByteSize = denseElementsAttr.getRawData().size() / |
| 401 | denseElementsAttr.getNumElements(); |
| 402 | if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) |
| 403 | return nullptr; |
| 404 | |
| 405 | // Compute the shape of all dimensions but the innermost. Note that the |
| 406 | // innermost dimension may be that of the vector element type. |
| 407 | bool hasVectorElementType = isa<VectorType>(type.getElementType()); |
| 408 | int64_t numAggregates = |
| 409 | denseElementsAttr.getNumElements() / |
| 410 | (hasVectorElementType ? 1 |
| 411 | : denseElementsAttr.getType().getShape().back()); |
| 412 | ArrayRef<int64_t> outerShape = type.getShape(); |
| 413 | if (!hasVectorElementType) |
| 414 | outerShape = outerShape.drop_back(); |
| 415 | |
| 416 | // Handle the case of vector splat, LLVM has special support for it. |
| 417 | if (denseElementsAttr.isSplat() && |
| 418 | (isa<VectorType>(type) || hasVectorElementType)) { |
| 419 | llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( |
| 420 | llvmType: innermostLLVMType, attr: denseElementsAttr.getSplatValue<Attribute>(), loc, |
| 421 | moduleTranslation); |
| 422 | llvm::Constant *splatVector = |
| 423 | llvm::ConstantDataVector::getSplat(NumElts: 0, Elt: splatValue); |
| 424 | SmallVector<llvm::Constant *> constants(numAggregates, splatVector); |
| 425 | ArrayRef<llvm::Constant *> constantsRef = constants; |
| 426 | return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc); |
| 427 | } |
| 428 | if (denseElementsAttr.isSplat()) |
| 429 | return nullptr; |
| 430 | |
| 431 | // In case of non-splat, create a constructor for the innermost constant from |
| 432 | // a piece of raw data. |
| 433 | std::function<llvm::Constant *(StringRef)> buildCstData; |
| 434 | if (isa<TensorType>(type)) { |
| 435 | auto vectorElementType = dyn_cast<VectorType>(type.getElementType()); |
| 436 | if (vectorElementType && vectorElementType.getRank() == 1) { |
| 437 | buildCstData = [&](StringRef data) { |
| 438 | return llvm::ConstantDataVector::getRaw( |
| 439 | data, vectorElementType.getShape().back(), innermostLLVMType); |
| 440 | }; |
| 441 | } else if (!vectorElementType) { |
| 442 | buildCstData = [&](StringRef data) { |
| 443 | return llvm::ConstantDataArray::getRaw(data, type.getShape().back(), |
| 444 | innermostLLVMType); |
| 445 | }; |
| 446 | } |
| 447 | } else if (isa<VectorType>(type)) { |
| 448 | buildCstData = [&](StringRef data) { |
| 449 | return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), |
| 450 | innermostLLVMType); |
| 451 | }; |
| 452 | } |
| 453 | if (!buildCstData) |
| 454 | return nullptr; |
| 455 | |
| 456 | // Create innermost constants and defer to the default constant creation |
| 457 | // mechanism for other dimensions. |
| 458 | SmallVector<llvm::Constant *> constants; |
| 459 | int64_t aggregateSize = denseElementsAttr.getType().getShape().back() * |
| 460 | (innermostLLVMType->getScalarSizeInBits() / 8); |
| 461 | constants.reserve(N: numAggregates); |
| 462 | for (unsigned i = 0; i < numAggregates; ++i) { |
| 463 | StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize, |
| 464 | aggregateSize); |
| 465 | constants.push_back(Elt: buildCstData(data)); |
| 466 | } |
| 467 | |
| 468 | ArrayRef<llvm::Constant *> constantsRef = constants; |
| 469 | return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc); |
| 470 | } |
| 471 | |
| 472 | /// Convert a dense resource elements attribute to an LLVM IR constant using its |
| 473 | /// raw data storage if possible. This supports elements attributes of tensor or |
| 474 | /// vector type and avoids constructing separate objects for individual values |
| 475 | /// of the innermost dimension. Constants for other dimensions are still |
| 476 | /// constructed recursively. Returns nullptr on failure and emits errors at |
| 477 | /// `loc`. |
| 478 | static llvm::Constant *convertDenseResourceElementsAttr( |
| 479 | Location loc, DenseResourceElementsAttr denseResourceAttr, |
| 480 | llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) { |
| 481 | assert(denseResourceAttr && "expected non-null attribute" ); |
| 482 | |
| 483 | llvm::Type *innermostLLVMType = getInnermostElementType(type: llvmType); |
| 484 | if (!llvm::ConstantDataSequential::isElementTypeCompatible( |
| 485 | Ty: innermostLLVMType)) { |
| 486 | emitError(loc, message: "no known conversion for innermost element type" ); |
| 487 | return nullptr; |
| 488 | } |
| 489 | |
| 490 | ShapedType type = denseResourceAttr.getType(); |
| 491 | assert(type.getNumElements() > 0 && "Expected non-empty elements attribute" ); |
| 492 | |
| 493 | AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob(); |
| 494 | if (!blob) { |
| 495 | emitError(loc, message: "resource does not exist" ); |
| 496 | return nullptr; |
| 497 | } |
| 498 | |
| 499 | ArrayRef<char> rawData = blob->getData(); |
| 500 | |
| 501 | // Check that the raw data size matches what is expected for the scalar size. |
| 502 | // TODO: in theory, we could repack the data here to keep constructing from |
| 503 | // raw data. |
| 504 | // TODO: we may also need to consider endianness when cross-compiling to an |
| 505 | // architecture where it is different. |
| 506 | int64_t numElements = denseResourceAttr.getType().getNumElements(); |
| 507 | int64_t elementByteSize = rawData.size() / numElements; |
| 508 | if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) { |
| 509 | emitError(loc, message: "raw data size does not match element type size" ); |
| 510 | return nullptr; |
| 511 | } |
| 512 | |
| 513 | // Compute the shape of all dimensions but the innermost. Note that the |
| 514 | // innermost dimension may be that of the vector element type. |
| 515 | bool hasVectorElementType = isa<VectorType>(type.getElementType()); |
| 516 | int64_t numAggregates = |
| 517 | numElements / (hasVectorElementType |
| 518 | ? 1 |
| 519 | : denseResourceAttr.getType().getShape().back()); |
| 520 | ArrayRef<int64_t> outerShape = type.getShape(); |
| 521 | if (!hasVectorElementType) |
| 522 | outerShape = outerShape.drop_back(); |
| 523 | |
| 524 | // Create a constructor for the innermost constant from a piece of raw data. |
| 525 | std::function<llvm::Constant *(StringRef)> buildCstData; |
| 526 | if (isa<TensorType>(type)) { |
| 527 | auto vectorElementType = dyn_cast<VectorType>(type.getElementType()); |
| 528 | if (vectorElementType && vectorElementType.getRank() == 1) { |
| 529 | buildCstData = [&](StringRef data) { |
| 530 | return llvm::ConstantDataVector::getRaw( |
| 531 | data, vectorElementType.getShape().back(), innermostLLVMType); |
| 532 | }; |
| 533 | } else if (!vectorElementType) { |
| 534 | buildCstData = [&](StringRef data) { |
| 535 | return llvm::ConstantDataArray::getRaw(data, type.getShape().back(), |
| 536 | innermostLLVMType); |
| 537 | }; |
| 538 | } |
| 539 | } else if (isa<VectorType>(type)) { |
| 540 | buildCstData = [&](StringRef data) { |
| 541 | return llvm::ConstantDataVector::getRaw(data, type.getShape().back(), |
| 542 | innermostLLVMType); |
| 543 | }; |
| 544 | } |
| 545 | if (!buildCstData) { |
| 546 | emitError(loc, message: "unsupported dense_resource type" ); |
| 547 | return nullptr; |
| 548 | } |
| 549 | |
| 550 | // Create innermost constants and defer to the default constant creation |
| 551 | // mechanism for other dimensions. |
| 552 | SmallVector<llvm::Constant *> constants; |
| 553 | int64_t aggregateSize = denseResourceAttr.getType().getShape().back() * |
| 554 | (innermostLLVMType->getScalarSizeInBits() / 8); |
| 555 | constants.reserve(N: numAggregates); |
| 556 | for (unsigned i = 0; i < numAggregates; ++i) { |
| 557 | StringRef data(rawData.data() + i * aggregateSize, aggregateSize); |
| 558 | constants.push_back(Elt: buildCstData(data)); |
| 559 | } |
| 560 | |
| 561 | ArrayRef<llvm::Constant *> constantsRef = constants; |
| 562 | return buildSequentialConstant(constants&: constantsRef, shape: outerShape, type: llvmType, loc); |
| 563 | } |
| 564 | |
| 565 | /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. |
| 566 | /// This currently supports integer, floating point, splat and dense element |
| 567 | /// attributes and combinations thereof. Also, an array attribute with two |
| 568 | /// elements is supported to represent a complex constant. In case of error, |
| 569 | /// report it to `loc` and return nullptr. |
| 570 | llvm::Constant *mlir::LLVM::detail::getLLVMConstant( |
| 571 | llvm::Type *llvmType, Attribute attr, Location loc, |
| 572 | const ModuleTranslation &moduleTranslation) { |
| 573 | if (!attr || isa<UndefAttr>(attr)) |
| 574 | return llvm::UndefValue::get(T: llvmType); |
| 575 | if (isa<ZeroAttr>(attr)) |
| 576 | return llvm::Constant::getNullValue(Ty: llvmType); |
| 577 | if (auto *structType = dyn_cast<::llvm::StructType>(Val: llvmType)) { |
| 578 | auto arrayAttr = dyn_cast<ArrayAttr>(attr); |
| 579 | if (!arrayAttr) { |
| 580 | emitError(loc, message: "expected an array attribute for a struct constant" ); |
| 581 | return nullptr; |
| 582 | } |
| 583 | SmallVector<llvm::Constant *> structElements; |
| 584 | structElements.reserve(N: structType->getNumElements()); |
| 585 | for (auto [elemType, elemAttr] : |
| 586 | zip_equal(structType->elements(), arrayAttr)) { |
| 587 | llvm::Constant *element = |
| 588 | getLLVMConstant(elemType, elemAttr, loc, moduleTranslation); |
| 589 | if (!element) |
| 590 | return nullptr; |
| 591 | structElements.push_back(element); |
| 592 | } |
| 593 | return llvm::ConstantStruct::get(T: structType, V: structElements); |
| 594 | } |
| 595 | // For integer types, we allow a mismatch in sizes as the index type in |
| 596 | // MLIR might have a different size than the index type in the LLVM module. |
| 597 | if (auto intAttr = dyn_cast<IntegerAttr>(attr)) |
| 598 | return llvm::ConstantInt::get( |
| 599 | llvmType, |
| 600 | intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); |
| 601 | if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { |
| 602 | const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); |
| 603 | // Special case for 8-bit floats, which are represented by integers due to |
| 604 | // the lack of native fp8 types in LLVM at the moment. Additionally, handle |
| 605 | // targets (like AMDGPU) that don't implement bfloat and convert all bfloats |
| 606 | // to i16. |
| 607 | unsigned floatWidth = APFloat::getSizeInBits(Sem: sem); |
| 608 | if (llvmType->isIntegerTy(Bitwidth: floatWidth)) |
| 609 | return llvm::ConstantInt::get(llvmType, |
| 610 | floatAttr.getValue().bitcastToAPInt()); |
| 611 | if (llvmType != |
| 612 | llvm::Type::getFloatingPointTy(C&: llvmType->getContext(), |
| 613 | S: floatAttr.getValue().getSemantics())) { |
| 614 | emitError(loc, message: "FloatAttr does not match expected type of the constant" ); |
| 615 | return nullptr; |
| 616 | } |
| 617 | return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); |
| 618 | } |
| 619 | if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr)) |
| 620 | return llvm::ConstantExpr::getBitCast( |
| 621 | C: moduleTranslation.lookupFunction(name: funcAttr.getValue()), Ty: llvmType); |
| 622 | if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr)) { |
| 623 | llvm::Type *elementType; |
| 624 | uint64_t numElements; |
| 625 | bool isScalable = false; |
| 626 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) { |
| 627 | elementType = arrayTy->getElementType(); |
| 628 | numElements = arrayTy->getNumElements(); |
| 629 | } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(Val: llvmType)) { |
| 630 | elementType = fVectorTy->getElementType(); |
| 631 | numElements = fVectorTy->getNumElements(); |
| 632 | } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(Val: llvmType)) { |
| 633 | elementType = sVectorTy->getElementType(); |
| 634 | numElements = sVectorTy->getMinNumElements(); |
| 635 | isScalable = true; |
| 636 | } else { |
| 637 | llvm_unreachable("unrecognized constant vector type" ); |
| 638 | } |
| 639 | // Splat value is a scalar. Extract it only if the element type is not |
| 640 | // another sequence type. The recursion terminates because each step removes |
| 641 | // one outer sequential type. |
| 642 | bool elementTypeSequential = |
| 643 | isa<llvm::ArrayType, llvm::VectorType>(Val: elementType); |
| 644 | llvm::Constant *child = getLLVMConstant( |
| 645 | llvmType: elementType, |
| 646 | attr: elementTypeSequential ? splatAttr |
| 647 | : splatAttr.getSplatValue<Attribute>(), |
| 648 | loc, moduleTranslation); |
| 649 | if (!child) |
| 650 | return nullptr; |
| 651 | if (llvmType->isVectorTy()) |
| 652 | return llvm::ConstantVector::getSplat( |
| 653 | EC: llvm::ElementCount::get(MinVal: numElements, /*Scalable=*/isScalable), Elt: child); |
| 654 | if (llvmType->isArrayTy()) { |
| 655 | auto *arrayType = llvm::ArrayType::get(ElementType: elementType, NumElements: numElements); |
| 656 | if (child->isZeroValue()) { |
| 657 | return llvm::ConstantAggregateZero::get(Ty: arrayType); |
| 658 | } else { |
| 659 | if (llvm::ConstantDataSequential::isElementTypeCompatible( |
| 660 | Ty: elementType)) { |
| 661 | // TODO: Handle all compatible types. This code only handles integer. |
| 662 | if (isa<llvm::IntegerType>(Val: elementType)) { |
| 663 | if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(Val: child)) { |
| 664 | if (ci->getBitWidth() == 8) { |
| 665 | SmallVector<int8_t> constants(numElements, ci->getZExtValue()); |
| 666 | return llvm::ConstantDataArray::get(Context&: elementType->getContext(), |
| 667 | Elts&: constants); |
| 668 | } |
| 669 | if (ci->getBitWidth() == 16) { |
| 670 | SmallVector<int16_t> constants(numElements, ci->getZExtValue()); |
| 671 | return llvm::ConstantDataArray::get(Context&: elementType->getContext(), |
| 672 | Elts&: constants); |
| 673 | } |
| 674 | if (ci->getBitWidth() == 32) { |
| 675 | SmallVector<int32_t> constants(numElements, ci->getZExtValue()); |
| 676 | return llvm::ConstantDataArray::get(Context&: elementType->getContext(), |
| 677 | Elts&: constants); |
| 678 | } |
| 679 | if (ci->getBitWidth() == 64) { |
| 680 | SmallVector<int64_t> constants(numElements, ci->getZExtValue()); |
| 681 | return llvm::ConstantDataArray::get(Context&: elementType->getContext(), |
| 682 | Elts&: constants); |
| 683 | } |
| 684 | } |
| 685 | } |
| 686 | } |
| 687 | // std::vector is used here to accommodate large number of elements that |
| 688 | // exceed SmallVector capacity. |
| 689 | std::vector<llvm::Constant *> constants(numElements, child); |
| 690 | return llvm::ConstantArray::get(T: arrayType, V: constants); |
| 691 | } |
| 692 | } |
| 693 | } |
| 694 | |
| 695 | // Try using raw elements data if possible. |
| 696 | if (llvm::Constant *result = |
| 697 | convertDenseElementsAttr(loc, denseElementsAttr: dyn_cast<DenseElementsAttr>(Val&: attr), |
| 698 | llvmType, moduleTranslation)) { |
| 699 | return result; |
| 700 | } |
| 701 | |
| 702 | if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) { |
| 703 | return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType, |
| 704 | moduleTranslation); |
| 705 | } |
| 706 | |
| 707 | // Fall back to element-by-element construction otherwise. |
| 708 | if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) { |
| 709 | assert(elementsAttr.getShapedType().hasStaticShape()); |
| 710 | assert(!elementsAttr.getShapedType().getShape().empty() && |
| 711 | "unexpected empty elements attribute shape" ); |
| 712 | |
| 713 | SmallVector<llvm::Constant *, 8> constants; |
| 714 | constants.reserve(N: elementsAttr.getNumElements()); |
| 715 | llvm::Type *innermostType = getInnermostElementType(type: llvmType); |
| 716 | for (auto n : elementsAttr.getValues<Attribute>()) { |
| 717 | constants.push_back( |
| 718 | getLLVMConstant(innermostType, n, loc, moduleTranslation)); |
| 719 | if (!constants.back()) |
| 720 | return nullptr; |
| 721 | } |
| 722 | ArrayRef<llvm::Constant *> constantsRef = constants; |
| 723 | llvm::Constant *result = buildSequentialConstant( |
| 724 | constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc); |
| 725 | assert(constantsRef.empty() && "did not consume all elemental constants" ); |
| 726 | return result; |
| 727 | } |
| 728 | |
| 729 | if (auto stringAttr = dyn_cast<StringAttr>(attr)) { |
| 730 | return llvm::ConstantDataArray::get( |
| 731 | Context&: moduleTranslation.getLLVMContext(), |
| 732 | Elts: ArrayRef<char>{stringAttr.getValue().data(), |
| 733 | stringAttr.getValue().size()}); |
| 734 | } |
| 735 | |
| 736 | // Handle arrays of structs that cannot be represented as DenseElementsAttr |
| 737 | // in MLIR. |
| 738 | if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { |
| 739 | if (auto *arrayTy = dyn_cast<llvm::ArrayType>(Val: llvmType)) { |
| 740 | llvm::Type *elementType = arrayTy->getElementType(); |
| 741 | Attribute previousElementAttr; |
| 742 | llvm::Constant *elementCst = nullptr; |
| 743 | SmallVector<llvm::Constant *> constants; |
| 744 | constants.reserve(N: arrayTy->getNumElements()); |
| 745 | for (Attribute elementAttr : arrayAttr) { |
| 746 | // Arrays with a single value or with repeating values are quite common. |
| 747 | // Short-circuit the translation when the element value is the same as |
| 748 | // the previous one. |
| 749 | if (!previousElementAttr || previousElementAttr != elementAttr) { |
| 750 | previousElementAttr = elementAttr; |
| 751 | elementCst = |
| 752 | getLLVMConstant(elementType, elementAttr, loc, moduleTranslation); |
| 753 | if (!elementCst) |
| 754 | return nullptr; |
| 755 | } |
| 756 | constants.push_back(elementCst); |
| 757 | } |
| 758 | return llvm::ConstantArray::get(T: arrayTy, V: constants); |
| 759 | } |
| 760 | } |
| 761 | |
| 762 | emitError(loc, message: "unsupported constant value" ); |
| 763 | return nullptr; |
| 764 | } |
| 765 | |
| 766 | ModuleTranslation::ModuleTranslation(Operation *module, |
| 767 | std::unique_ptr<llvm::Module> llvmModule) |
| 768 | : mlirModule(module), llvmModule(std::move(llvmModule)), |
| 769 | debugTranslation( |
| 770 | std::make_unique<DebugTranslation>(args&: module, args&: *this->llvmModule)), |
| 771 | loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>( |
| 772 | args&: *this, args&: *this->llvmModule)), |
| 773 | typeTranslator(this->llvmModule->getContext()), |
| 774 | iface(module->getContext()) { |
| 775 | assert(satisfiesLLVMModule(mlirModule) && |
| 776 | "mlirModule should honor LLVM's module semantics." ); |
| 777 | } |
| 778 | |
| 779 | ModuleTranslation::~ModuleTranslation() { |
| 780 | if (ompBuilder) |
| 781 | ompBuilder->finalize(); |
| 782 | } |
| 783 | |
| 784 | void ModuleTranslation::forgetMapping(Region ®ion) { |
| 785 | SmallVector<Region *> toProcess; |
| 786 | toProcess.push_back(Elt: ®ion); |
| 787 | while (!toProcess.empty()) { |
| 788 | Region *current = toProcess.pop_back_val(); |
| 789 | for (Block &block : *current) { |
| 790 | blockMapping.erase(Val: &block); |
| 791 | for (Value arg : block.getArguments()) |
| 792 | valueMapping.erase(Val: arg); |
| 793 | for (Operation &op : block) { |
| 794 | for (Value value : op.getResults()) |
| 795 | valueMapping.erase(Val: value); |
| 796 | if (op.hasSuccessors()) |
| 797 | branchMapping.erase(Val: &op); |
| 798 | if (isa<LLVM::GlobalOp>(op)) |
| 799 | globalsMapping.erase(Val: &op); |
| 800 | if (isa<LLVM::AliasOp>(op)) |
| 801 | aliasesMapping.erase(Val: &op); |
| 802 | if (isa<LLVM::CallOp>(op)) |
| 803 | callMapping.erase(Val: &op); |
| 804 | llvm::append_range( |
| 805 | C&: toProcess, |
| 806 | R: llvm::map_range(C: op.getRegions(), F: [](Region &r) { return &r; })); |
| 807 | } |
| 808 | } |
| 809 | } |
| 810 | } |
| 811 | |
| 812 | /// Get the SSA value passed to the current block from the terminator operation |
| 813 | /// of its predecessor. |
| 814 | static Value getPHISourceValue(Block *current, Block *pred, |
| 815 | unsigned numArguments, unsigned index) { |
| 816 | Operation &terminator = *pred->getTerminator(); |
| 817 | if (isa<LLVM::BrOp>(terminator)) |
| 818 | return terminator.getOperand(idx: index); |
| 819 | |
| 820 | #ifndef NDEBUG |
| 821 | llvm::SmallPtrSet<Block *, 4> seenSuccessors; |
| 822 | for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) { |
| 823 | Block *successor = terminator.getSuccessor(index: i); |
| 824 | auto branch = cast<BranchOpInterface>(terminator); |
| 825 | SuccessorOperands successorOperands = branch.getSuccessorOperands(i); |
| 826 | assert( |
| 827 | (!seenSuccessors.contains(successor) || successorOperands.empty()) && |
| 828 | "successors with arguments in LLVM branches must be different blocks" ); |
| 829 | seenSuccessors.insert(Ptr: successor); |
| 830 | } |
| 831 | #endif |
| 832 | |
| 833 | // For instructions that branch based on a condition value, we need to take |
| 834 | // the operands for the branch that was taken. |
| 835 | if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) { |
| 836 | // For conditional branches, we take the operands from either the "true" or |
| 837 | // the "false" branch. |
| 838 | return condBranchOp.getSuccessor(0) == current |
| 839 | ? condBranchOp.getTrueDestOperands()[index] |
| 840 | : condBranchOp.getFalseDestOperands()[index]; |
| 841 | } |
| 842 | |
| 843 | if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) { |
| 844 | // For switches, we take the operands from either the default case, or from |
| 845 | // the case branch that was taken. |
| 846 | if (switchOp.getDefaultDestination() == current) |
| 847 | return switchOp.getDefaultOperands()[index]; |
| 848 | for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations())) |
| 849 | if (i.value() == current) |
| 850 | return switchOp.getCaseOperands(i.index())[index]; |
| 851 | } |
| 852 | |
| 853 | if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(terminator)) { |
| 854 | // For indirect branches we take operands for each successor. |
| 855 | for (const auto &i : llvm::enumerate(indBrOp->getSuccessors())) { |
| 856 | if (indBrOp->getSuccessor(i.index()) == current) |
| 857 | return indBrOp.getSuccessorOperands(i.index())[index]; |
| 858 | } |
| 859 | } |
| 860 | |
| 861 | if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) { |
| 862 | return invokeOp.getNormalDest() == current |
| 863 | ? invokeOp.getNormalDestOperands()[index] |
| 864 | : invokeOp.getUnwindDestOperands()[index]; |
| 865 | } |
| 866 | |
| 867 | llvm_unreachable( |
| 868 | "only branch, switch or invoke operations can be terminators " |
| 869 | "of a block that has successors" ); |
| 870 | } |
| 871 | |
| 872 | /// Connect the PHI nodes to the results of preceding blocks. |
| 873 | void mlir::LLVM::detail::connectPHINodes(Region ®ion, |
| 874 | const ModuleTranslation &state) { |
| 875 | // Skip the first block, it cannot be branched to and its arguments correspond |
| 876 | // to the arguments of the LLVM function. |
| 877 | for (Block &bb : llvm::drop_begin(RangeOrContainer&: region)) { |
| 878 | llvm::BasicBlock *llvmBB = state.lookupBlock(block: &bb); |
| 879 | auto phis = llvmBB->phis(); |
| 880 | auto numArguments = bb.getNumArguments(); |
| 881 | assert(numArguments == std::distance(phis.begin(), phis.end())); |
| 882 | for (auto [index, phiNode] : llvm::enumerate(First&: phis)) { |
| 883 | for (auto *pred : bb.getPredecessors()) { |
| 884 | // Find the LLVM IR block that contains the converted terminator |
| 885 | // instruction and use it in the PHI node. Note that this block is not |
| 886 | // necessarily the same as state.lookupBlock(pred), some operations |
| 887 | // (in particular, OpenMP operations using OpenMPIRBuilder) may have |
| 888 | // split the blocks. |
| 889 | llvm::Instruction *terminator = |
| 890 | state.lookupBranch(op: pred->getTerminator()); |
| 891 | assert(terminator && "missing the mapping for a terminator" ); |
| 892 | phiNode.addIncoming(V: state.lookupValue(value: getPHISourceValue( |
| 893 | current: &bb, pred, numArguments, index)), |
| 894 | BB: terminator->getParent()); |
| 895 | } |
| 896 | } |
| 897 | } |
| 898 | } |
| 899 | |
| 900 | llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( |
| 901 | llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, |
| 902 | ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) { |
| 903 | llvm::Module *module = builder.GetInsertBlock()->getModule(); |
| 904 | llvm::Function *fn = |
| 905 | llvm::Intrinsic::getOrInsertDeclaration(M: module, id: intrinsic, Tys: tys); |
| 906 | return builder.CreateCall(Callee: fn, Args: args); |
| 907 | } |
| 908 | |
| 909 | llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( |
| 910 | llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation, |
| 911 | Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults, |
| 912 | ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands, |
| 913 | ArrayRef<unsigned> immArgPositions, |
| 914 | ArrayRef<StringLiteral> immArgAttrNames) { |
| 915 | assert(immArgPositions.size() == immArgAttrNames.size() && |
| 916 | "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal " |
| 917 | "length" ); |
| 918 | |
| 919 | SmallVector<llvm::OperandBundleDef> opBundles; |
| 920 | size_t numOpBundleOperands = 0; |
| 921 | auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>( |
| 922 | intrOp->getAttr(LLVMDialect::getOpBundleSizesAttrName())); |
| 923 | auto opBundleTagsAttr = cast_if_present<ArrayAttr>( |
| 924 | intrOp->getAttr(LLVMDialect::getOpBundleTagsAttrName())); |
| 925 | |
| 926 | if (opBundleSizesAttr && opBundleTagsAttr) { |
| 927 | ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef(); |
| 928 | assert(opBundleSizes.size() == opBundleTagsAttr.size() && |
| 929 | "operand bundles and tags do not match" ); |
| 930 | |
| 931 | numOpBundleOperands = |
| 932 | std::accumulate(first: opBundleSizes.begin(), last: opBundleSizes.end(), init: size_t(0)); |
| 933 | assert(numOpBundleOperands <= intrOp->getNumOperands() && |
| 934 | "operand bundle operands is more than the number of operands" ); |
| 935 | |
| 936 | ValueRange operands = intrOp->getOperands().take_back(n: numOpBundleOperands); |
| 937 | size_t nextOperandIdx = 0; |
| 938 | opBundles.reserve(N: opBundleSizesAttr.size()); |
| 939 | |
| 940 | for (auto [opBundleTagAttr, bundleSize] : |
| 941 | llvm::zip(opBundleTagsAttr, opBundleSizes)) { |
| 942 | auto bundleTag = cast<StringAttr>(opBundleTagAttr).str(); |
| 943 | auto bundleOperands = moduleTranslation.lookupValues( |
| 944 | operands.slice(nextOperandIdx, bundleSize)); |
| 945 | opBundles.emplace_back(std::move(bundleTag), std::move(bundleOperands)); |
| 946 | nextOperandIdx += bundleSize; |
| 947 | } |
| 948 | } |
| 949 | |
| 950 | // Map operands and attributes to LLVM values. |
| 951 | auto opOperands = intrOp->getOperands().drop_back(n: numOpBundleOperands); |
| 952 | auto operands = moduleTranslation.lookupValues(values: opOperands); |
| 953 | SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size()); |
| 954 | for (auto [immArgPos, immArgName] : |
| 955 | llvm::zip(t&: immArgPositions, u&: immArgAttrNames)) { |
| 956 | auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName)); |
| 957 | assert(attr.getType().isIntOrFloat() && "expected int or float immarg" ); |
| 958 | auto *type = moduleTranslation.convertType(type: attr.getType()); |
| 959 | args[immArgPos] = LLVM::detail::getLLVMConstant( |
| 960 | llvmType: type, attr: attr, loc: intrOp->getLoc(), moduleTranslation); |
| 961 | } |
| 962 | unsigned opArg = 0; |
| 963 | for (auto &arg : args) { |
| 964 | if (!arg) |
| 965 | arg = operands[opArg++]; |
| 966 | } |
| 967 | |
| 968 | // Resolve overloaded intrinsic declaration. |
| 969 | SmallVector<llvm::Type *> overloadedTypes; |
| 970 | for (unsigned overloadedResultIdx : overloadedResults) { |
| 971 | if (numResults > 1) { |
| 972 | // More than one result is mapped to an LLVM struct. |
| 973 | overloadedTypes.push_back(moduleTranslation.convertType( |
| 974 | llvm::cast<LLVM::LLVMStructType>(intrOp->getResult(0).getType()) |
| 975 | .getBody()[overloadedResultIdx])); |
| 976 | } else { |
| 977 | overloadedTypes.push_back( |
| 978 | Elt: moduleTranslation.convertType(type: intrOp->getResult(idx: 0).getType())); |
| 979 | } |
| 980 | } |
| 981 | for (unsigned overloadedOperandIdx : overloadedOperands) |
| 982 | overloadedTypes.push_back(Elt: args[overloadedOperandIdx]->getType()); |
| 983 | llvm::Module *module = builder.GetInsertBlock()->getModule(); |
| 984 | llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration( |
| 985 | M: module, id: intrinsic, Tys: overloadedTypes); |
| 986 | |
| 987 | return builder.CreateCall(Callee: llvmIntr, Args: args, OpBundles: opBundles); |
| 988 | } |
| 989 | |
| 990 | /// Given a single MLIR operation, create the corresponding LLVM IR operation |
| 991 | /// using the `builder`. |
| 992 | LogicalResult ModuleTranslation::convertOperation(Operation &op, |
| 993 | llvm::IRBuilderBase &builder, |
| 994 | bool recordInsertions) { |
| 995 | const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(obj: &op); |
| 996 | if (!opIface) |
| 997 | return op.emitError(message: "cannot be converted to LLVM IR: missing " |
| 998 | "`LLVMTranslationDialectInterface` registration for " |
| 999 | "dialect for op: " ) |
| 1000 | << op.getName(); |
| 1001 | |
| 1002 | InstructionCapturingInserter::CollectionScope scope(builder, |
| 1003 | recordInsertions); |
| 1004 | if (failed(Result: opIface->convertOperation(op: &op, builder, moduleTranslation&: *this))) |
| 1005 | return op.emitError(message: "LLVM Translation failed for operation: " ) |
| 1006 | << op.getName(); |
| 1007 | |
| 1008 | return convertDialectAttributes(op: &op, instructions: scope.getCapturedInstructions()); |
| 1009 | } |
| 1010 | |
| 1011 | /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes |
| 1012 | /// to define values corresponding to the MLIR block arguments. These nodes |
| 1013 | /// are not connected to the source basic blocks, which may not exist yet. Uses |
| 1014 | /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have |
| 1015 | /// been created for `bb` and included in the block mapping. Inserts new |
| 1016 | /// instructions at the end of the block and leaves `builder` in a state |
| 1017 | /// suitable for further insertion into the end of the block. |
| 1018 | LogicalResult ModuleTranslation::convertBlockImpl(Block &bb, |
| 1019 | bool ignoreArguments, |
| 1020 | llvm::IRBuilderBase &builder, |
| 1021 | bool recordInsertions) { |
| 1022 | builder.SetInsertPoint(lookupBlock(block: &bb)); |
| 1023 | auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); |
| 1024 | |
| 1025 | // Before traversing operations, make block arguments available through |
| 1026 | // value remapping and PHI nodes, but do not add incoming edges for the PHI |
| 1027 | // nodes just yet: those values may be defined by this or following blocks. |
| 1028 | // This step is omitted if "ignoreArguments" is set. The arguments of the |
| 1029 | // first block have been already made available through the remapping of |
| 1030 | // LLVM function arguments. |
| 1031 | if (!ignoreArguments) { |
| 1032 | auto predecessors = bb.getPredecessors(); |
| 1033 | unsigned numPredecessors = |
| 1034 | std::distance(first: predecessors.begin(), last: predecessors.end()); |
| 1035 | for (auto arg : bb.getArguments()) { |
| 1036 | auto wrappedType = arg.getType(); |
| 1037 | if (!isCompatibleType(type: wrappedType)) |
| 1038 | return emitError(loc: bb.front().getLoc(), |
| 1039 | message: "block argument does not have an LLVM type" ); |
| 1040 | builder.SetCurrentDebugLocation( |
| 1041 | debugTranslation->translateLoc(loc: arg.getLoc(), scope: subprogram)); |
| 1042 | llvm::Type *type = convertType(type: wrappedType); |
| 1043 | llvm::PHINode *phi = builder.CreatePHI(Ty: type, NumReservedValues: numPredecessors); |
| 1044 | mapValue(mlir: arg, llvm: phi); |
| 1045 | } |
| 1046 | } |
| 1047 | |
| 1048 | // Traverse operations. |
| 1049 | for (auto &op : bb) { |
| 1050 | // Set the current debug location within the builder. |
| 1051 | builder.SetCurrentDebugLocation( |
| 1052 | debugTranslation->translateLoc(loc: op.getLoc(), scope: subprogram)); |
| 1053 | |
| 1054 | if (failed(Result: convertOperation(op, builder, recordInsertions))) |
| 1055 | return failure(); |
| 1056 | |
| 1057 | // Set the branch weight metadata on the translated instruction. |
| 1058 | if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) |
| 1059 | setBranchWeightsMetadata(iface); |
| 1060 | } |
| 1061 | |
| 1062 | return success(); |
| 1063 | } |
| 1064 | |
| 1065 | /// A helper method to get the single Block in an operation honoring LLVM's |
| 1066 | /// module requirements. |
| 1067 | static Block &getModuleBody(Operation *module) { |
| 1068 | return module->getRegion(index: 0).front(); |
| 1069 | } |
| 1070 | |
| 1071 | /// A helper method to decide if a constant must not be set as a global variable |
| 1072 | /// initializer. For an external linkage variable, the variable with an |
| 1073 | /// initializer is considered externally visible and defined in this module, the |
| 1074 | /// variable without an initializer is externally available and is defined |
| 1075 | /// elsewhere. |
| 1076 | static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, |
| 1077 | llvm::Constant *cst) { |
| 1078 | return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) || |
| 1079 | linkage == llvm::GlobalVariable::ExternalWeakLinkage; |
| 1080 | } |
| 1081 | |
| 1082 | /// Sets the runtime preemption specifier of `gv` to dso_local if |
| 1083 | /// `dsoLocalRequested` is true, otherwise it is left unchanged. |
| 1084 | static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, |
| 1085 | llvm::GlobalValue *gv) { |
| 1086 | if (dsoLocalRequested) |
| 1087 | gv->setDSOLocal(true); |
| 1088 | } |
| 1089 | |
| 1090 | LogicalResult ModuleTranslation::convertGlobalsAndAliases() { |
| 1091 | // Mapping from compile unit to its respective set of global variables. |
| 1092 | DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars; |
| 1093 | |
| 1094 | // First, create all global variables and global aliases in LLVM IR. A global |
| 1095 | // or alias body may refer to another global/alias or itself, so all the |
| 1096 | // mapping needs to happen prior to body conversion. |
| 1097 | |
| 1098 | // Create all llvm::GlobalVariable |
| 1099 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { |
| 1100 | llvm::Type *type = convertType(op.getType()); |
| 1101 | llvm::Constant *cst = nullptr; |
| 1102 | if (op.getValueOrNull()) { |
| 1103 | // String attributes are treated separately because they cannot appear as |
| 1104 | // in-function constants and are thus not supported by getLLVMConstant. |
| 1105 | if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) { |
| 1106 | cst = llvm::ConstantDataArray::getString( |
| 1107 | llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); |
| 1108 | type = cst->getType(); |
| 1109 | } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(), |
| 1110 | *this))) { |
| 1111 | return failure(); |
| 1112 | } |
| 1113 | } |
| 1114 | |
| 1115 | auto linkage = convertLinkageToLLVM(op.getLinkage()); |
| 1116 | |
| 1117 | // LLVM IR requires constant with linkage other than external or weak |
| 1118 | // external to have initializers. If MLIR does not provide an initializer, |
| 1119 | // default to undef. |
| 1120 | bool dropInitializer = shouldDropGlobalInitializer(linkage, cst); |
| 1121 | if (!dropInitializer && !cst) |
| 1122 | cst = llvm::UndefValue::get(type); |
| 1123 | else if (dropInitializer && cst) |
| 1124 | cst = nullptr; |
| 1125 | |
| 1126 | auto *var = new llvm::GlobalVariable( |
| 1127 | *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(), |
| 1128 | /*InsertBefore=*/nullptr, |
| 1129 | op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel |
| 1130 | : llvm::GlobalValue::NotThreadLocal, |
| 1131 | op.getAddrSpace(), op.getExternallyInitialized()); |
| 1132 | |
| 1133 | if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) { |
| 1134 | auto selectorOp = cast<ComdatSelectorOp>( |
| 1135 | SymbolTable::lookupNearestSymbolFrom(op, *comdat)); |
| 1136 | var->setComdat(comdatMapping.lookup(selectorOp)); |
| 1137 | } |
| 1138 | |
| 1139 | if (op.getUnnamedAddr().has_value()) |
| 1140 | var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); |
| 1141 | |
| 1142 | if (op.getSection().has_value()) |
| 1143 | var->setSection(*op.getSection()); |
| 1144 | |
| 1145 | addRuntimePreemptionSpecifier(op.getDsoLocal(), var); |
| 1146 | |
| 1147 | std::optional<uint64_t> alignment = op.getAlignment(); |
| 1148 | if (alignment.has_value()) |
| 1149 | var->setAlignment(llvm::MaybeAlign(alignment.value())); |
| 1150 | |
| 1151 | var->setVisibility(convertVisibilityToLLVM(op.getVisibility_())); |
| 1152 | |
| 1153 | globalsMapping.try_emplace(op, var); |
| 1154 | |
| 1155 | // Add debug information if present. |
| 1156 | if (op.getDbgExprs()) { |
| 1157 | for (auto exprAttr : |
| 1158 | op.getDbgExprs()->getAsRange<DIGlobalVariableExpressionAttr>()) { |
| 1159 | llvm::DIGlobalVariableExpression *diGlobalExpr = |
| 1160 | debugTranslation->translateGlobalVariableExpression(exprAttr); |
| 1161 | llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable(); |
| 1162 | var->addDebugInfo(diGlobalExpr); |
| 1163 | |
| 1164 | // There is no `globals` field in DICompileUnitAttr which can be |
| 1165 | // directly assigned to DICompileUnit. We have to build the list by |
| 1166 | // looking at the dbgExpr of all the GlobalOps. The scope of the |
| 1167 | // variable is used to get the DICompileUnit in which to add it. But |
| 1168 | // there are cases where the scope of a global does not directly point |
| 1169 | // to the DICompileUnit and we have to do a bit more work to get to |
| 1170 | // it. Some of those cases are: |
| 1171 | // |
| 1172 | // 1. For the languages that support modules, the scope hierarchy can |
| 1173 | // be variable -> DIModule -> DICompileUnit |
| 1174 | // |
| 1175 | // 2. For the Fortran common block variable, the scope hierarchy can |
| 1176 | // be variable -> DICommonBlock -> DISubprogram -> DICompileUnit |
| 1177 | // |
| 1178 | // 3. For entities like static local variables in C or variable with |
| 1179 | // SAVE attribute in Fortran, the scope hierarchy can be |
| 1180 | // variable -> DISubprogram -> DICompileUnit |
| 1181 | llvm::DIScope *scope = diGlobalVar->getScope(); |
| 1182 | if (auto *mod = dyn_cast_if_present<llvm::DIModule>(scope)) |
| 1183 | scope = mod->getScope(); |
| 1184 | else if (auto *cb = dyn_cast_if_present<llvm::DICommonBlock>(scope)) { |
| 1185 | if (auto *sp = |
| 1186 | dyn_cast_if_present<llvm::DISubprogram>(cb->getScope())) |
| 1187 | scope = sp->getUnit(); |
| 1188 | } else if (auto *sp = dyn_cast_if_present<llvm::DISubprogram>(scope)) |
| 1189 | scope = sp->getUnit(); |
| 1190 | |
| 1191 | // Get the compile unit (scope) of the global variable. |
| 1192 | if (llvm::DICompileUnit *compileUnit = |
| 1193 | dyn_cast_if_present<llvm::DICompileUnit>(scope)) { |
| 1194 | // Update the compile unit with this incoming global variable |
| 1195 | // expression during the finalizing step later. |
| 1196 | allGVars[compileUnit].push_back(diGlobalExpr); |
| 1197 | } |
| 1198 | } |
| 1199 | } |
| 1200 | } |
| 1201 | |
| 1202 | // Create all llvm::GlobalAlias |
| 1203 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) { |
| 1204 | llvm::Type *type = convertType(op.getType()); |
| 1205 | llvm::Constant *cst = nullptr; |
| 1206 | llvm::GlobalValue::LinkageTypes linkage = |
| 1207 | convertLinkageToLLVM(op.getLinkage()); |
| 1208 | llvm::Module &llvmMod = *llvmModule; |
| 1209 | |
| 1210 | // Note address space and aliasee info isn't set just yet. |
| 1211 | llvm::GlobalAlias *var = llvm::GlobalAlias::create( |
| 1212 | type, op.getAddrSpace(), linkage, op.getSymName(), /*placeholder*/ cst, |
| 1213 | &llvmMod); |
| 1214 | |
| 1215 | var->setThreadLocalMode(op.getThreadLocal_() |
| 1216 | ? llvm::GlobalAlias::GeneralDynamicTLSModel |
| 1217 | : llvm::GlobalAlias::NotThreadLocal); |
| 1218 | |
| 1219 | // Note there is no need to setup the comdat because GlobalAlias calls into |
| 1220 | // the aliasee comdat information automatically. |
| 1221 | |
| 1222 | if (op.getUnnamedAddr().has_value()) |
| 1223 | var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr())); |
| 1224 | |
| 1225 | var->setVisibility(convertVisibilityToLLVM(op.getVisibility_())); |
| 1226 | |
| 1227 | aliasesMapping.try_emplace(op, var); |
| 1228 | } |
| 1229 | |
| 1230 | // Convert global variable bodies. |
| 1231 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { |
| 1232 | if (Block *initializer = op.getInitializerBlock()) { |
| 1233 | llvm::IRBuilder<llvm::TargetFolder> builder( |
| 1234 | llvmModule->getContext(), |
| 1235 | llvm::TargetFolder(llvmModule->getDataLayout())); |
| 1236 | |
| 1237 | [[maybe_unused]] int numConstantsHit = 0; |
| 1238 | [[maybe_unused]] int numConstantsErased = 0; |
| 1239 | DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap; |
| 1240 | |
| 1241 | for (auto &op : initializer->without_terminator()) { |
| 1242 | if (failed(convertOperation(op, builder))) |
| 1243 | return emitError(op.getLoc(), "fail to convert global initializer" ); |
| 1244 | auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0))); |
| 1245 | if (!cst) |
| 1246 | return emitError(op.getLoc(), "unemittable constant value" ); |
| 1247 | |
| 1248 | // When emitting an LLVM constant, a new constant is created and the old |
| 1249 | // constant may become dangling and take space. We should remove the |
| 1250 | // dangling constants to avoid memory explosion especially for constant |
| 1251 | // arrays whose number of elements is large. |
| 1252 | // Because multiple operations may refer to the same constant, we need |
| 1253 | // to count the number of uses of each constant array and remove it only |
| 1254 | // when the count becomes zero. |
| 1255 | if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) { |
| 1256 | numConstantsHit++; |
| 1257 | Value result = op.getResult(0); |
| 1258 | int numUsers = std::distance(result.use_begin(), result.use_end()); |
| 1259 | auto [iterator, inserted] = |
| 1260 | constantAggregateUseMap.try_emplace(agg, numUsers); |
| 1261 | if (!inserted) { |
| 1262 | // Key already exists, update the value |
| 1263 | iterator->second += numUsers; |
| 1264 | } |
| 1265 | } |
| 1266 | // Scan the operands of the operation to decrement the use count of |
| 1267 | // constants. Erase the constant if the use count becomes zero. |
| 1268 | for (Value v : op.getOperands()) { |
| 1269 | auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v)); |
| 1270 | if (!cst) |
| 1271 | continue; |
| 1272 | auto iter = constantAggregateUseMap.find(cst); |
| 1273 | assert(iter != constantAggregateUseMap.end() && "constant not found" ); |
| 1274 | iter->second--; |
| 1275 | if (iter->second == 0) { |
| 1276 | // NOTE: cannot call removeDeadConstantUsers() here because it |
| 1277 | // may remove the constant which has uses not be converted yet. |
| 1278 | if (cst->user_empty()) { |
| 1279 | cst->destroyConstant(); |
| 1280 | numConstantsErased++; |
| 1281 | } |
| 1282 | constantAggregateUseMap.erase(iter); |
| 1283 | } |
| 1284 | } |
| 1285 | } |
| 1286 | |
| 1287 | ReturnOp ret = cast<ReturnOp>(initializer->getTerminator()); |
| 1288 | llvm::Constant *cst = |
| 1289 | cast<llvm::Constant>(lookupValue(ret.getOperand(0))); |
| 1290 | auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op)); |
| 1291 | if (!shouldDropGlobalInitializer(global->getLinkage(), cst)) |
| 1292 | global->setInitializer(cst); |
| 1293 | |
| 1294 | // Try to remove the dangling constants again after all operations are |
| 1295 | // converted. |
| 1296 | for (auto it : constantAggregateUseMap) { |
| 1297 | auto cst = it.first; |
| 1298 | cst->removeDeadConstantUsers(); |
| 1299 | if (cst->user_empty()) { |
| 1300 | cst->destroyConstant(); |
| 1301 | numConstantsErased++; |
| 1302 | } |
| 1303 | } |
| 1304 | |
| 1305 | LLVM_DEBUG(llvm::dbgs() |
| 1306 | << "Convert initializer for " << op.getName() << "\n" ; |
| 1307 | llvm::dbgs() << numConstantsHit << " new constants hit\n" ; |
| 1308 | llvm::dbgs() |
| 1309 | << numConstantsErased << " dangling constants erased\n" ;); |
| 1310 | } |
| 1311 | } |
| 1312 | |
| 1313 | // Convert llvm.mlir.global_ctors and dtors. |
| 1314 | for (Operation &op : getModuleBody(module: mlirModule)) { |
| 1315 | auto ctorOp = dyn_cast<GlobalCtorsOp>(op); |
| 1316 | auto dtorOp = dyn_cast<GlobalDtorsOp>(op); |
| 1317 | if (!ctorOp && !dtorOp) |
| 1318 | continue; |
| 1319 | |
| 1320 | // The empty / zero initialized version of llvm.global_(c|d)tors cannot be |
| 1321 | // handled by appendGlobalFn logic below, which just ignores empty (c|d)tor |
| 1322 | // lists. Make sure it gets emitted. |
| 1323 | if ((ctorOp && ctorOp.getCtors().empty()) || |
| 1324 | (dtorOp && dtorOp.getDtors().empty())) { |
| 1325 | llvm::IRBuilder<llvm::TargetFolder> builder( |
| 1326 | llvmModule->getContext(), |
| 1327 | llvm::TargetFolder(llvmModule->getDataLayout())); |
| 1328 | llvm::Type *eltTy = llvm::StructType::get( |
| 1329 | elt1: builder.getInt32Ty(), elts: builder.getPtrTy(), elts: builder.getPtrTy()); |
| 1330 | llvm::ArrayType *at = llvm::ArrayType::get(ElementType: eltTy, NumElements: 0); |
| 1331 | llvm::Constant *zeroInit = llvm::Constant::getNullValue(Ty: at); |
| 1332 | (void)new llvm::GlobalVariable( |
| 1333 | *llvmModule, zeroInit->getType(), false, |
| 1334 | llvm::GlobalValue::AppendingLinkage, zeroInit, |
| 1335 | ctorOp ? "llvm.global_ctors" : "llvm.global_dtors" ); |
| 1336 | } else { |
| 1337 | auto range = ctorOp |
| 1338 | ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities()) |
| 1339 | : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities()); |
| 1340 | auto appendGlobalFn = |
| 1341 | ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors; |
| 1342 | for (const auto &[sym, prio] : range) { |
| 1343 | llvm::Function *f = |
| 1344 | lookupFunction(cast<FlatSymbolRefAttr>(sym).getValue()); |
| 1345 | appendGlobalFn(*llvmModule, f, cast<IntegerAttr>(prio).getInt(), |
| 1346 | /*Data=*/nullptr); |
| 1347 | } |
| 1348 | } |
| 1349 | } |
| 1350 | |
| 1351 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) |
| 1352 | if (failed(convertDialectAttributes(op, {}))) |
| 1353 | return failure(); |
| 1354 | |
| 1355 | // Finally, update the compile units with their respective sets of global |
| 1356 | // variables created earlier. |
| 1357 | for (const auto &[compileUnit, globals] : allGVars) { |
| 1358 | compileUnit->replaceGlobalVariables( |
| 1359 | N: llvm::MDTuple::get(Context&: getLLVMContext(), MDs: globals)); |
| 1360 | } |
| 1361 | |
| 1362 | // Convert global alias bodies. |
| 1363 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) { |
| 1364 | Block &initializer = op.getInitializerBlock(); |
| 1365 | llvm::IRBuilder<llvm::TargetFolder> builder( |
| 1366 | llvmModule->getContext(), |
| 1367 | llvm::TargetFolder(llvmModule->getDataLayout())); |
| 1368 | |
| 1369 | for (mlir::Operation &op : initializer.without_terminator()) { |
| 1370 | if (failed(convertOperation(op, builder))) |
| 1371 | return emitError(op.getLoc(), "fail to convert alias initializer" ); |
| 1372 | if (!isa<llvm::Constant>(lookupValue(op.getResult(0)))) |
| 1373 | return emitError(op.getLoc(), "unemittable constant value" ); |
| 1374 | } |
| 1375 | |
| 1376 | auto ret = cast<ReturnOp>(initializer.getTerminator()); |
| 1377 | auto *cst = cast<llvm::Constant>(lookupValue(ret.getOperand(0))); |
| 1378 | assert(aliasesMapping.count(op)); |
| 1379 | auto *alias = cast<llvm::GlobalAlias>(aliasesMapping[op]); |
| 1380 | alias->setAliasee(cst); |
| 1381 | } |
| 1382 | |
| 1383 | for (auto op : getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) |
| 1384 | if (failed(convertDialectAttributes(op, {}))) |
| 1385 | return failure(); |
| 1386 | |
| 1387 | return success(); |
| 1388 | } |
| 1389 | |
| 1390 | /// Attempts to add an attribute identified by `key`, optionally with the given |
| 1391 | /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the |
| 1392 | /// attribute has a kind known to LLVM IR, create the attribute of this kind, |
| 1393 | /// otherwise keep it as a string attribute. Performs additional checks for |
| 1394 | /// attributes known to have or not have a value in order to avoid assertions |
| 1395 | /// inside LLVM upon construction. |
| 1396 | static LogicalResult checkedAddLLVMFnAttribute(Location loc, |
| 1397 | llvm::Function *llvmFunc, |
| 1398 | StringRef key, |
| 1399 | StringRef value = StringRef()) { |
| 1400 | auto kind = llvm::Attribute::getAttrKindFromName(AttrName: key); |
| 1401 | if (kind == llvm::Attribute::None) { |
| 1402 | llvmFunc->addFnAttr(Kind: key, Val: value); |
| 1403 | return success(); |
| 1404 | } |
| 1405 | |
| 1406 | if (llvm::Attribute::isIntAttrKind(Kind: kind)) { |
| 1407 | if (value.empty()) |
| 1408 | return emitError(loc) << "LLVM attribute '" << key << "' expects a value" ; |
| 1409 | |
| 1410 | int64_t result; |
| 1411 | if (!value.getAsInteger(/*Radix=*/0, Result&: result)) |
| 1412 | llvmFunc->addFnAttr( |
| 1413 | Attr: llvm::Attribute::get(Context&: llvmFunc->getContext(), Kind: kind, Val: result)); |
| 1414 | else |
| 1415 | llvmFunc->addFnAttr(Kind: key, Val: value); |
| 1416 | return success(); |
| 1417 | } |
| 1418 | |
| 1419 | if (!value.empty()) |
| 1420 | return emitError(loc) << "LLVM attribute '" << key |
| 1421 | << "' does not expect a value, found '" << value |
| 1422 | << "'" ; |
| 1423 | |
| 1424 | llvmFunc->addFnAttr(Kind: kind); |
| 1425 | return success(); |
| 1426 | } |
| 1427 | |
| 1428 | /// Return a representation of `value` as metadata. |
| 1429 | static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context, |
| 1430 | const llvm::APInt &value) { |
| 1431 | llvm::Constant *constant = llvm::ConstantInt::get(Context&: context, V: value); |
| 1432 | return llvm::ConstantAsMetadata::get(C: constant); |
| 1433 | } |
| 1434 | |
| 1435 | /// Return a representation of `value` as an MDNode. |
| 1436 | static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context, |
| 1437 | const llvm::APInt &value) { |
| 1438 | return llvm::MDNode::get(Context&: context, MDs: convertIntegerToMetadata(context, value)); |
| 1439 | } |
| 1440 | |
| 1441 | /// Return an MDNode encoding `vec_type_hint` metadata. |
| 1442 | static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context, |
| 1443 | llvm::Type *type, |
| 1444 | bool isSigned) { |
| 1445 | llvm::Metadata *typeMD = |
| 1446 | llvm::ConstantAsMetadata::get(C: llvm::UndefValue::get(T: type)); |
| 1447 | llvm::Metadata *isSignedMD = |
| 1448 | convertIntegerToMetadata(context, value: llvm::APInt(32, isSigned ? 1 : 0)); |
| 1449 | return llvm::MDNode::get(Context&: context, MDs: {typeMD, isSignedMD}); |
| 1450 | } |
| 1451 | |
| 1452 | /// Return an MDNode with a tuple given by the values in `values`. |
| 1453 | static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context, |
| 1454 | ArrayRef<int32_t> values) { |
| 1455 | SmallVector<llvm::Metadata *> mdValues; |
| 1456 | llvm::transform( |
| 1457 | Range&: values, d_first: std::back_inserter(x&: mdValues), F: [&context](int32_t value) { |
| 1458 | return convertIntegerToMetadata(context, value: llvm::APInt(32, value)); |
| 1459 | }); |
| 1460 | return llvm::MDNode::get(Context&: context, MDs: mdValues); |
| 1461 | } |
| 1462 | |
| 1463 | /// Attaches the attributes listed in the given array attribute to `llvmFunc`. |
| 1464 | /// Reports error to `loc` if any and returns immediately. Expects `attributes` |
| 1465 | /// to be an array attribute containing either string attributes, treated as |
| 1466 | /// value-less LLVM attributes, or array attributes containing two string |
| 1467 | /// attributes, with the first string being the name of the corresponding LLVM |
| 1468 | /// attribute and the second string beings its value. Note that even integer |
| 1469 | /// attributes are expected to have their values expressed as strings. |
| 1470 | static LogicalResult |
| 1471 | forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes, |
| 1472 | llvm::Function *llvmFunc) { |
| 1473 | if (!attributes) |
| 1474 | return success(); |
| 1475 | |
| 1476 | for (Attribute attr : *attributes) { |
| 1477 | if (auto stringAttr = dyn_cast<StringAttr>(attr)) { |
| 1478 | if (failed( |
| 1479 | checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) |
| 1480 | return failure(); |
| 1481 | continue; |
| 1482 | } |
| 1483 | |
| 1484 | auto arrayAttr = dyn_cast<ArrayAttr>(attr); |
| 1485 | if (!arrayAttr || arrayAttr.size() != 2) |
| 1486 | return emitError(loc) |
| 1487 | << "expected 'passthrough' to contain string or array attributes" ; |
| 1488 | |
| 1489 | auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]); |
| 1490 | auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]); |
| 1491 | if (!keyAttr || !valueAttr) |
| 1492 | return emitError(loc) |
| 1493 | << "expected arrays within 'passthrough' to contain two strings" ; |
| 1494 | |
| 1495 | if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(), |
| 1496 | valueAttr.getValue()))) |
| 1497 | return failure(); |
| 1498 | } |
| 1499 | return success(); |
| 1500 | } |
| 1501 | |
| 1502 | LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { |
| 1503 | // Clear the block, branch value mappings, they are only relevant within one |
| 1504 | // function. |
| 1505 | blockMapping.clear(); |
| 1506 | valueMapping.clear(); |
| 1507 | branchMapping.clear(); |
| 1508 | llvm::Function *llvmFunc = lookupFunction(name: func.getName()); |
| 1509 | |
| 1510 | // Add function arguments to the value remapping table. |
| 1511 | for (auto [mlirArg, llvmArg] : |
| 1512 | llvm::zip(func.getArguments(), llvmFunc->args())) |
| 1513 | mapValue(mlirArg, &llvmArg); |
| 1514 | |
| 1515 | // Check the personality and set it. |
| 1516 | if (func.getPersonality()) { |
| 1517 | llvm::Type *ty = llvm::PointerType::getUnqual(C&: llvmFunc->getContext()); |
| 1518 | if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(), |
| 1519 | func.getLoc(), *this)) |
| 1520 | llvmFunc->setPersonalityFn(pfunc); |
| 1521 | } |
| 1522 | |
| 1523 | if (std::optional<StringRef> section = func.getSection()) |
| 1524 | llvmFunc->setSection(*section); |
| 1525 | |
| 1526 | if (func.getArmStreaming()) |
| 1527 | llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_enabled" ); |
| 1528 | else if (func.getArmLocallyStreaming()) |
| 1529 | llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_body" ); |
| 1530 | else if (func.getArmStreamingCompatible()) |
| 1531 | llvmFunc->addFnAttr(Kind: "aarch64_pstate_sm_compatible" ); |
| 1532 | |
| 1533 | if (func.getArmNewZa()) |
| 1534 | llvmFunc->addFnAttr(Kind: "aarch64_new_za" ); |
| 1535 | else if (func.getArmInZa()) |
| 1536 | llvmFunc->addFnAttr(Kind: "aarch64_in_za" ); |
| 1537 | else if (func.getArmOutZa()) |
| 1538 | llvmFunc->addFnAttr(Kind: "aarch64_out_za" ); |
| 1539 | else if (func.getArmInoutZa()) |
| 1540 | llvmFunc->addFnAttr(Kind: "aarch64_inout_za" ); |
| 1541 | else if (func.getArmPreservesZa()) |
| 1542 | llvmFunc->addFnAttr(Kind: "aarch64_preserves_za" ); |
| 1543 | |
| 1544 | if (auto targetCpu = func.getTargetCpu()) |
| 1545 | llvmFunc->addFnAttr("target-cpu" , *targetCpu); |
| 1546 | |
| 1547 | if (auto tuneCpu = func.getTuneCpu()) |
| 1548 | llvmFunc->addFnAttr("tune-cpu" , *tuneCpu); |
| 1549 | |
| 1550 | if (auto reciprocalEstimates = func.getReciprocalEstimates()) |
| 1551 | llvmFunc->addFnAttr("reciprocal-estimates" , *reciprocalEstimates); |
| 1552 | |
| 1553 | if (auto preferVectorWidth = func.getPreferVectorWidth()) |
| 1554 | llvmFunc->addFnAttr("prefer-vector-width" , *preferVectorWidth); |
| 1555 | |
| 1556 | if (auto attr = func.getVscaleRange()) |
| 1557 | llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs( |
| 1558 | Context&: getLLVMContext(), MinValue: attr->getMinRange().getInt(), |
| 1559 | MaxValue: attr->getMaxRange().getInt())); |
| 1560 | |
| 1561 | if (auto unsafeFpMath = func.getUnsafeFpMath()) |
| 1562 | llvmFunc->addFnAttr("unsafe-fp-math" , llvm::toStringRef(*unsafeFpMath)); |
| 1563 | |
| 1564 | if (auto noInfsFpMath = func.getNoInfsFpMath()) |
| 1565 | llvmFunc->addFnAttr("no-infs-fp-math" , llvm::toStringRef(*noInfsFpMath)); |
| 1566 | |
| 1567 | if (auto noNansFpMath = func.getNoNansFpMath()) |
| 1568 | llvmFunc->addFnAttr("no-nans-fp-math" , llvm::toStringRef(*noNansFpMath)); |
| 1569 | |
| 1570 | if (auto approxFuncFpMath = func.getApproxFuncFpMath()) |
| 1571 | llvmFunc->addFnAttr("approx-func-fp-math" , |
| 1572 | llvm::toStringRef(*approxFuncFpMath)); |
| 1573 | |
| 1574 | if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath()) |
| 1575 | llvmFunc->addFnAttr("no-signed-zeros-fp-math" , |
| 1576 | llvm::toStringRef(*noSignedZerosFpMath)); |
| 1577 | |
| 1578 | if (auto denormalFpMath = func.getDenormalFpMath()) |
| 1579 | llvmFunc->addFnAttr("denormal-fp-math" , *denormalFpMath); |
| 1580 | |
| 1581 | if (auto denormalFpMathF32 = func.getDenormalFpMathF32()) |
| 1582 | llvmFunc->addFnAttr("denormal-fp-math-f32" , *denormalFpMathF32); |
| 1583 | |
| 1584 | if (auto fpContract = func.getFpContract()) |
| 1585 | llvmFunc->addFnAttr("fp-contract" , *fpContract); |
| 1586 | |
| 1587 | if (auto instrumentFunctionEntry = func.getInstrumentFunctionEntry()) |
| 1588 | llvmFunc->addFnAttr("instrument-function-entry" , *instrumentFunctionEntry); |
| 1589 | |
| 1590 | if (auto instrumentFunctionExit = func.getInstrumentFunctionExit()) |
| 1591 | llvmFunc->addFnAttr("instrument-function-exit" , *instrumentFunctionExit); |
| 1592 | |
| 1593 | // First, create all blocks so we can jump to them. |
| 1594 | llvm::LLVMContext &llvmContext = llvmFunc->getContext(); |
| 1595 | for (auto &bb : func) { |
| 1596 | auto *llvmBB = llvm::BasicBlock::Create(llvmContext); |
| 1597 | llvmBB->insertInto(llvmFunc); |
| 1598 | mapBlock(&bb, llvmBB); |
| 1599 | } |
| 1600 | |
| 1601 | // Then, convert blocks one by one in topological order to ensure defs are |
| 1602 | // converted before uses. |
| 1603 | auto blocks = getBlocksSortedByDominance(func.getBody()); |
| 1604 | for (Block *bb : blocks) { |
| 1605 | CapturingIRBuilder builder(llvmContext, |
| 1606 | llvm::TargetFolder(llvmModule->getDataLayout())); |
| 1607 | if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder, |
| 1608 | /*recordInsertions=*/true))) |
| 1609 | return failure(); |
| 1610 | } |
| 1611 | |
| 1612 | // After all blocks have been traversed and values mapped, connect the PHI |
| 1613 | // nodes to the results of preceding blocks. |
| 1614 | detail::connectPHINodes(region&: func.getBody(), state: *this); |
| 1615 | |
| 1616 | // Finally, convert dialect attributes attached to the function. |
| 1617 | return convertDialectAttributes(op: func, instructions: {}); |
| 1618 | } |
| 1619 | |
| 1620 | LogicalResult ModuleTranslation::convertDialectAttributes( |
| 1621 | Operation *op, ArrayRef<llvm::Instruction *> instructions) { |
| 1622 | for (NamedAttribute attribute : op->getDialectAttrs()) |
| 1623 | if (failed(Result: iface.amendOperation(op, instructions, attribute, moduleTranslation&: *this))) |
| 1624 | return failure(); |
| 1625 | return success(); |
| 1626 | } |
| 1627 | |
| 1628 | /// Converts memory effect attributes from `func` and attaches them to |
| 1629 | /// `llvmFunc`. |
| 1630 | static void convertFunctionMemoryAttributes(LLVMFuncOp func, |
| 1631 | llvm::Function *llvmFunc) { |
| 1632 | if (!func.getMemoryEffects()) |
| 1633 | return; |
| 1634 | |
| 1635 | MemoryEffectsAttr memEffects = func.getMemoryEffectsAttr(); |
| 1636 | |
| 1637 | // Add memory effects incrementally. |
| 1638 | llvm::MemoryEffects newMemEffects = |
| 1639 | llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, |
| 1640 | convertModRefInfoToLLVM(memEffects.getArgMem())); |
| 1641 | newMemEffects |= llvm::MemoryEffects( |
| 1642 | llvm::MemoryEffects::Location::InaccessibleMem, |
| 1643 | convertModRefInfoToLLVM(memEffects.getInaccessibleMem())); |
| 1644 | newMemEffects |= |
| 1645 | llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, |
| 1646 | convertModRefInfoToLLVM(memEffects.getOther())); |
| 1647 | llvmFunc->setMemoryEffects(newMemEffects); |
| 1648 | } |
| 1649 | |
| 1650 | /// Converts function attributes from `func` and attaches them to `llvmFunc`. |
| 1651 | static void convertFunctionAttributes(LLVMFuncOp func, |
| 1652 | llvm::Function *llvmFunc) { |
| 1653 | if (func.getNoInlineAttr()) |
| 1654 | llvmFunc->addFnAttr(llvm::Attribute::NoInline); |
| 1655 | if (func.getAlwaysInlineAttr()) |
| 1656 | llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline); |
| 1657 | if (func.getOptimizeNoneAttr()) |
| 1658 | llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone); |
| 1659 | if (func.getConvergentAttr()) |
| 1660 | llvmFunc->addFnAttr(llvm::Attribute::Convergent); |
| 1661 | if (func.getNoUnwindAttr()) |
| 1662 | llvmFunc->addFnAttr(llvm::Attribute::NoUnwind); |
| 1663 | if (func.getWillReturnAttr()) |
| 1664 | llvmFunc->addFnAttr(llvm::Attribute::WillReturn); |
| 1665 | if (TargetFeaturesAttr targetFeatAttr = func.getTargetFeaturesAttr()) |
| 1666 | llvmFunc->addFnAttr("target-features" , targetFeatAttr.getFeaturesString()); |
| 1667 | if (FramePointerKindAttr fpAttr = func.getFramePointerAttr()) |
| 1668 | llvmFunc->addFnAttr("frame-pointer" , stringifyFramePointerKind( |
| 1669 | fpAttr.getFramePointerKind())); |
| 1670 | if (UWTableKindAttr uwTableKindAttr = func.getUwtableKindAttr()) |
| 1671 | llvmFunc->setUWTableKind( |
| 1672 | convertUWTableKindToLLVM(uwTableKindAttr.getUwtableKind())); |
| 1673 | convertFunctionMemoryAttributes(func, llvmFunc); |
| 1674 | } |
| 1675 | |
| 1676 | /// Converts function attributes from `func` and attaches them to `llvmFunc`. |
| 1677 | static void convertFunctionKernelAttributes(LLVMFuncOp func, |
| 1678 | llvm::Function *llvmFunc, |
| 1679 | ModuleTranslation &translation) { |
| 1680 | llvm::LLVMContext &llvmContext = llvmFunc->getContext(); |
| 1681 | |
| 1682 | if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) { |
| 1683 | Type type = vecTypeHint.getHint().getValue(); |
| 1684 | llvm::Type *llvmType = translation.convertType(type); |
| 1685 | bool isSigned = vecTypeHint.getIsSigned(); |
| 1686 | llvmFunc->setMetadata( |
| 1687 | func.getVecTypeHintAttrName(), |
| 1688 | convertVecTypeHintToMDNode(context&: llvmContext, type: llvmType, isSigned)); |
| 1689 | } |
| 1690 | |
| 1691 | if (std::optional<ArrayRef<int32_t>> workGroupSizeHint = |
| 1692 | func.getWorkGroupSizeHint()) { |
| 1693 | llvmFunc->setMetadata( |
| 1694 | func.getWorkGroupSizeHintAttrName(), |
| 1695 | convertIntegerArrayToMDNode(context&: llvmContext, values: *workGroupSizeHint)); |
| 1696 | } |
| 1697 | |
| 1698 | if (std::optional<ArrayRef<int32_t>> reqdWorkGroupSize = |
| 1699 | func.getReqdWorkGroupSize()) { |
| 1700 | llvmFunc->setMetadata( |
| 1701 | func.getReqdWorkGroupSizeAttrName(), |
| 1702 | convertIntegerArrayToMDNode(context&: llvmContext, values: *reqdWorkGroupSize)); |
| 1703 | } |
| 1704 | |
| 1705 | if (std::optional<uint32_t> intelReqdSubGroupSize = |
| 1706 | func.getIntelReqdSubGroupSize()) { |
| 1707 | llvmFunc->setMetadata( |
| 1708 | func.getIntelReqdSubGroupSizeAttrName(), |
| 1709 | convertIntegerToMDNode(context&: llvmContext, |
| 1710 | value: llvm::APInt(32, *intelReqdSubGroupSize))); |
| 1711 | } |
| 1712 | } |
| 1713 | |
| 1714 | static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, |
| 1715 | llvm::Attribute::AttrKind llvmKind, |
| 1716 | NamedAttribute namedAttr, |
| 1717 | ModuleTranslation &moduleTranslation, |
| 1718 | Location loc) { |
| 1719 | return llvm::TypeSwitch<Attribute, LogicalResult>(namedAttr.getValue()) |
| 1720 | .Case<TypeAttr>([&](auto typeAttr) { |
| 1721 | attrBuilder.addTypeAttr( |
| 1722 | llvmKind, moduleTranslation.convertType(typeAttr.getValue())); |
| 1723 | return success(); |
| 1724 | }) |
| 1725 | .Case<IntegerAttr>([&](auto intAttr) { |
| 1726 | attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt()); |
| 1727 | return success(); |
| 1728 | }) |
| 1729 | .Case<UnitAttr>([&](auto) { |
| 1730 | attrBuilder.addAttribute(llvmKind); |
| 1731 | return success(); |
| 1732 | }) |
| 1733 | .Case<LLVM::ConstantRangeAttr>([&](auto rangeAttr) { |
| 1734 | attrBuilder.addConstantRangeAttr( |
| 1735 | llvmKind, |
| 1736 | llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper())); |
| 1737 | return success(); |
| 1738 | }) |
| 1739 | .Default([loc](auto) { |
| 1740 | return emitError(loc, "unsupported parameter attribute type" ); |
| 1741 | }); |
| 1742 | } |
| 1743 | |
| 1744 | FailureOr<llvm::AttrBuilder> |
| 1745 | ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, |
| 1746 | DictionaryAttr paramAttrs) { |
| 1747 | llvm::AttrBuilder attrBuilder(llvmModule->getContext()); |
| 1748 | auto attrNameToKindMapping = getAttrNameToKindMapping(); |
| 1749 | Location loc = func.getLoc(); |
| 1750 | |
| 1751 | for (auto namedAttr : paramAttrs) { |
| 1752 | auto it = attrNameToKindMapping.find(namedAttr.getName()); |
| 1753 | if (it != attrNameToKindMapping.end()) { |
| 1754 | llvm::Attribute::AttrKind llvmKind = it->second; |
| 1755 | if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this, |
| 1756 | loc))) |
| 1757 | return failure(); |
| 1758 | } else if (namedAttr.getNameDialect()) { |
| 1759 | if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this))) |
| 1760 | return failure(); |
| 1761 | } |
| 1762 | } |
| 1763 | |
| 1764 | return attrBuilder; |
| 1765 | } |
| 1766 | |
| 1767 | FailureOr<llvm::AttrBuilder> |
| 1768 | ModuleTranslation::convertParameterAttrs(Location loc, |
| 1769 | DictionaryAttr paramAttrs) { |
| 1770 | llvm::AttrBuilder attrBuilder(llvmModule->getContext()); |
| 1771 | auto attrNameToKindMapping = getAttrNameToKindMapping(); |
| 1772 | |
| 1773 | for (auto namedAttr : paramAttrs) { |
| 1774 | auto it = attrNameToKindMapping.find(namedAttr.getName()); |
| 1775 | if (it != attrNameToKindMapping.end()) { |
| 1776 | llvm::Attribute::AttrKind llvmKind = it->second; |
| 1777 | if (failed(convertParameterAttr(attrBuilder, llvmKind, namedAttr, *this, |
| 1778 | loc))) |
| 1779 | return failure(); |
| 1780 | } |
| 1781 | } |
| 1782 | |
| 1783 | return attrBuilder; |
| 1784 | } |
| 1785 | |
| 1786 | LogicalResult ModuleTranslation::convertFunctionSignatures() { |
| 1787 | // Declare all functions first because there may be function calls that form a |
| 1788 | // call graph with cycles, or global initializers that reference functions. |
| 1789 | for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { |
| 1790 | llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( |
| 1791 | function.getName(), |
| 1792 | cast<llvm::FunctionType>(convertType(function.getFunctionType()))); |
| 1793 | llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee()); |
| 1794 | llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage())); |
| 1795 | llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv())); |
| 1796 | mapFunction(function.getName(), llvmFunc); |
| 1797 | addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc); |
| 1798 | |
| 1799 | // Convert function attributes. |
| 1800 | convertFunctionAttributes(function, llvmFunc); |
| 1801 | |
| 1802 | // Convert function kernel attributes to metadata. |
| 1803 | convertFunctionKernelAttributes(function, llvmFunc, *this); |
| 1804 | |
| 1805 | // Convert function_entry_count attribute to metadata. |
| 1806 | if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount()) |
| 1807 | llvmFunc->setEntryCount(entryCount.value()); |
| 1808 | |
| 1809 | // Convert result attributes. |
| 1810 | if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) { |
| 1811 | DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]); |
| 1812 | FailureOr<llvm::AttrBuilder> attrBuilder = |
| 1813 | convertParameterAttrs(function, -1, resultAttrs); |
| 1814 | if (failed(attrBuilder)) |
| 1815 | return failure(); |
| 1816 | llvmFunc->addRetAttrs(*attrBuilder); |
| 1817 | } |
| 1818 | |
| 1819 | // Convert argument attributes. |
| 1820 | for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) { |
| 1821 | if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) { |
| 1822 | FailureOr<llvm::AttrBuilder> attrBuilder = |
| 1823 | convertParameterAttrs(function, argIdx, argAttrs); |
| 1824 | if (failed(attrBuilder)) |
| 1825 | return failure(); |
| 1826 | llvmArg.addAttrs(*attrBuilder); |
| 1827 | } |
| 1828 | } |
| 1829 | |
| 1830 | // Forward the pass-through attributes to LLVM. |
| 1831 | if (failed(forwardPassthroughAttributes( |
| 1832 | function.getLoc(), function.getPassthrough(), llvmFunc))) |
| 1833 | return failure(); |
| 1834 | |
| 1835 | // Convert visibility attribute. |
| 1836 | llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_())); |
| 1837 | |
| 1838 | // Convert the comdat attribute. |
| 1839 | if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) { |
| 1840 | auto selectorOp = cast<ComdatSelectorOp>( |
| 1841 | SymbolTable::lookupNearestSymbolFrom(function, *comdat)); |
| 1842 | llvmFunc->setComdat(comdatMapping.lookup(selectorOp)); |
| 1843 | } |
| 1844 | |
| 1845 | if (auto gc = function.getGarbageCollector()) |
| 1846 | llvmFunc->setGC(gc->str()); |
| 1847 | |
| 1848 | if (auto unnamedAddr = function.getUnnamedAddr()) |
| 1849 | llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr)); |
| 1850 | |
| 1851 | if (auto alignment = function.getAlignment()) |
| 1852 | llvmFunc->setAlignment(llvm::MaybeAlign(*alignment)); |
| 1853 | |
| 1854 | // Translate the debug information for this function. |
| 1855 | debugTranslation->translate(function, *llvmFunc); |
| 1856 | } |
| 1857 | |
| 1858 | return success(); |
| 1859 | } |
| 1860 | |
| 1861 | LogicalResult ModuleTranslation::convertFunctions() { |
| 1862 | // Convert functions. |
| 1863 | for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { |
| 1864 | // Do not convert external functions, but do process dialect attributes |
| 1865 | // attached to them. |
| 1866 | if (function.isExternal()) { |
| 1867 | if (failed(convertDialectAttributes(function, {}))) |
| 1868 | return failure(); |
| 1869 | continue; |
| 1870 | } |
| 1871 | |
| 1872 | if (failed(convertOneFunction(function))) |
| 1873 | return failure(); |
| 1874 | } |
| 1875 | |
| 1876 | return success(); |
| 1877 | } |
| 1878 | |
| 1879 | LogicalResult ModuleTranslation::convertComdats() { |
| 1880 | for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) { |
| 1881 | for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) { |
| 1882 | llvm::Module *module = getLLVMModule(); |
| 1883 | if (module->getComdatSymbolTable().contains(selectorOp.getSymName())) |
| 1884 | return emitError(selectorOp.getLoc()) |
| 1885 | << "comdat selection symbols must be unique even in different " |
| 1886 | "comdat regions" ; |
| 1887 | llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName()); |
| 1888 | comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat())); |
| 1889 | comdatMapping.try_emplace(selectorOp, comdat); |
| 1890 | } |
| 1891 | } |
| 1892 | return success(); |
| 1893 | } |
| 1894 | |
| 1895 | LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() { |
| 1896 | for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) { |
| 1897 | BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr(); |
| 1898 | llvm::BasicBlock *llvmBlock = lookupBlockAddress(blockAddressAttr); |
| 1899 | assert(llvmBlock && "expected LLVM blocks to be already translated" ); |
| 1900 | |
| 1901 | // Update mapping with new block address constant. |
| 1902 | auto *llvmBlockAddr = llvm::BlockAddress::get( |
| 1903 | lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock); |
| 1904 | llvmCst->replaceAllUsesWith(llvmBlockAddr); |
| 1905 | assert(llvmCst->use_empty() && "expected all uses to be replaced" ); |
| 1906 | cast<llvm::GlobalVariable>(llvmCst)->eraseFromParent(); |
| 1907 | } |
| 1908 | unresolvedBlockAddressMapping.clear(); |
| 1909 | return success(); |
| 1910 | } |
| 1911 | |
| 1912 | void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op, |
| 1913 | llvm::Instruction *inst) { |
| 1914 | if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op)) |
| 1915 | inst->setMetadata(KindID: llvm::LLVMContext::MD_access_group, Node: node); |
| 1916 | } |
| 1917 | |
| 1918 | llvm::MDNode * |
| 1919 | ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) { |
| 1920 | auto [scopeIt, scopeInserted] = |
| 1921 | aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr); |
| 1922 | if (!scopeInserted) |
| 1923 | return scopeIt->second; |
| 1924 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
| 1925 | auto dummy = llvm::MDNode::getTemporary(Context&: ctx, MDs: std::nullopt); |
| 1926 | // Convert the domain metadata node if necessary. |
| 1927 | auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace( |
| 1928 | aliasScopeAttr.getDomain(), nullptr); |
| 1929 | if (insertedDomain) { |
| 1930 | llvm::SmallVector<llvm::Metadata *, 2> operands; |
| 1931 | // Placeholder for potential self-reference. |
| 1932 | operands.push_back(Elt: dummy.get()); |
| 1933 | if (StringAttr description = aliasScopeAttr.getDomain().getDescription()) |
| 1934 | operands.push_back(Elt: llvm::MDString::get(ctx, description)); |
| 1935 | domainIt->second = llvm::MDNode::get(ctx, operands); |
| 1936 | // Self-reference for uniqueness. |
| 1937 | llvm::Metadata *replacement; |
| 1938 | if (auto stringAttr = |
| 1939 | dyn_cast<StringAttr>(aliasScopeAttr.getDomain().getId())) |
| 1940 | replacement = llvm::MDString::get(ctx, stringAttr.getValue()); |
| 1941 | else |
| 1942 | replacement = domainIt->second; |
| 1943 | domainIt->second->replaceOperandWith(0, replacement); |
| 1944 | } |
| 1945 | // Convert the scope metadata node. |
| 1946 | assert(domainIt->second && "Scope's domain should already be valid" ); |
| 1947 | llvm::SmallVector<llvm::Metadata *, 3> operands; |
| 1948 | // Placeholder for potential self-reference. |
| 1949 | operands.push_back(Elt: dummy.get()); |
| 1950 | operands.push_back(Elt: domainIt->second); |
| 1951 | if (StringAttr description = aliasScopeAttr.getDescription()) |
| 1952 | operands.push_back(Elt: llvm::MDString::get(ctx, description)); |
| 1953 | scopeIt->second = llvm::MDNode::get(ctx, operands); |
| 1954 | // Self-reference for uniqueness. |
| 1955 | llvm::Metadata *replacement; |
| 1956 | if (auto stringAttr = dyn_cast<StringAttr>(aliasScopeAttr.getId())) |
| 1957 | replacement = llvm::MDString::get(ctx, stringAttr.getValue()); |
| 1958 | else |
| 1959 | replacement = scopeIt->second; |
| 1960 | scopeIt->second->replaceOperandWith(0, replacement); |
| 1961 | return scopeIt->second; |
| 1962 | } |
| 1963 | |
| 1964 | llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes( |
| 1965 | ArrayRef<AliasScopeAttr> aliasScopeAttrs) { |
| 1966 | SmallVector<llvm::Metadata *> nodes; |
| 1967 | nodes.reserve(N: aliasScopeAttrs.size()); |
| 1968 | for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs) |
| 1969 | nodes.push_back(getOrCreateAliasScope(aliasScopeAttr)); |
| 1970 | return llvm::MDNode::get(Context&: getLLVMContext(), MDs: nodes); |
| 1971 | } |
| 1972 | |
| 1973 | void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op, |
| 1974 | llvm::Instruction *inst) { |
| 1975 | auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) { |
| 1976 | if (!aliasScopeAttrs || aliasScopeAttrs.empty()) |
| 1977 | return; |
| 1978 | llvm::MDNode *node = getOrCreateAliasScopes( |
| 1979 | aliasScopeAttrs: llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>())); |
| 1980 | inst->setMetadata(KindID: kind, Node: node); |
| 1981 | }; |
| 1982 | |
| 1983 | populateScopeMetadata(op.getAliasScopesOrNull(), |
| 1984 | llvm::LLVMContext::MD_alias_scope); |
| 1985 | populateScopeMetadata(op.getNoAliasScopesOrNull(), |
| 1986 | llvm::LLVMContext::MD_noalias); |
| 1987 | } |
| 1988 | |
| 1989 | llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const { |
| 1990 | return tbaaMetadataMapping.lookup(Val: tbaaAttr); |
| 1991 | } |
| 1992 | |
| 1993 | void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, |
| 1994 | llvm::Instruction *inst) { |
| 1995 | ArrayAttr tagRefs = op.getTBAATagsOrNull(); |
| 1996 | if (!tagRefs || tagRefs.empty()) |
| 1997 | return; |
| 1998 | |
| 1999 | // LLVM IR currently does not support attaching more than one TBAA access tag |
| 2000 | // to a memory accessing instruction. It may be useful to support this in |
| 2001 | // future, but for the time being just ignore the metadata if MLIR operation |
| 2002 | // has multiple access tags. |
| 2003 | if (tagRefs.size() > 1) { |
| 2004 | op.emitWarning() << "TBAA access tags were not translated, because LLVM " |
| 2005 | "IR only supports a single tag per instruction" ; |
| 2006 | return; |
| 2007 | } |
| 2008 | |
| 2009 | llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0])); |
| 2010 | inst->setMetadata(KindID: llvm::LLVMContext::MD_tbaa, Node: node); |
| 2011 | } |
| 2012 | |
| 2013 | void ModuleTranslation::setDereferenceableMetadata( |
| 2014 | DereferenceableOpInterface op, llvm::Instruction *inst) { |
| 2015 | DereferenceableAttr derefAttr = op.getDereferenceableOrNull(); |
| 2016 | if (!derefAttr) |
| 2017 | return; |
| 2018 | |
| 2019 | llvm::MDNode *derefSizeNode = llvm::MDNode::get( |
| 2020 | Context&: getLLVMContext(), |
| 2021 | MDs: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get( |
| 2022 | llvm::IntegerType::get(C&: getLLVMContext(), NumBits: 64), derefAttr.getBytes()))); |
| 2023 | unsigned kindId = derefAttr.getMayBeNull() |
| 2024 | ? llvm::LLVMContext::MD_dereferenceable_or_null |
| 2025 | : llvm::LLVMContext::MD_dereferenceable; |
| 2026 | inst->setMetadata(KindID: kindId, Node: derefSizeNode); |
| 2027 | } |
| 2028 | |
| 2029 | void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { |
| 2030 | DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); |
| 2031 | if (!weightsAttr) |
| 2032 | return; |
| 2033 | |
| 2034 | llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op); |
| 2035 | assert(inst && "expected the operation to have a mapping to an instruction" ); |
| 2036 | SmallVector<uint32_t> weights(weightsAttr.asArrayRef()); |
| 2037 | inst->setMetadata( |
| 2038 | KindID: llvm::LLVMContext::MD_prof, |
| 2039 | Node: llvm::MDBuilder(getLLVMContext()).createBranchWeights(Weights: weights)); |
| 2040 | } |
| 2041 | |
| 2042 | LogicalResult ModuleTranslation::createTBAAMetadata() { |
| 2043 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
| 2044 | llvm::IntegerType *offsetTy = llvm::IntegerType::get(C&: ctx, NumBits: 64); |
| 2045 | |
| 2046 | // Walk the entire module and create all metadata nodes for the TBAA |
| 2047 | // attributes. The code below relies on two invariants of the |
| 2048 | // `AttrTypeWalker`: |
| 2049 | // 1. Attributes are visited in post-order: Since the attributes create a DAG, |
| 2050 | // this ensures that any lookups into `tbaaMetadataMapping` for child |
| 2051 | // attributes succeed. |
| 2052 | // 2. Attributes are only ever visited once: This way we don't leak any |
| 2053 | // LLVM metadata instances. |
| 2054 | AttrTypeWalker walker; |
| 2055 | walker.addWalk(callback: [&](TBAARootAttr root) { |
| 2056 | tbaaMetadataMapping.insert( |
| 2057 | {root, llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(ctx, root.getId()))}); |
| 2058 | }); |
| 2059 | |
| 2060 | walker.addWalk(callback: [&](TBAATypeDescriptorAttr descriptor) { |
| 2061 | SmallVector<llvm::Metadata *> operands; |
| 2062 | operands.push_back(Elt: llvm::MDString::get(ctx, descriptor.getId())); |
| 2063 | for (TBAAMemberAttr member : descriptor.getMembers()) { |
| 2064 | operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc())); |
| 2065 | operands.push_back(llvm::ConstantAsMetadata::get( |
| 2066 | llvm::ConstantInt::get(offsetTy, member.getOffset()))); |
| 2067 | } |
| 2068 | |
| 2069 | tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(Context&: ctx, MDs: operands)}); |
| 2070 | }); |
| 2071 | |
| 2072 | walker.addWalk(callback: [&](TBAATagAttr tag) { |
| 2073 | SmallVector<llvm::Metadata *> operands; |
| 2074 | |
| 2075 | operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getBaseType())); |
| 2076 | operands.push_back(Elt: tbaaMetadataMapping.lookup(Val: tag.getAccessType())); |
| 2077 | |
| 2078 | operands.push_back(Elt: llvm::ConstantAsMetadata::get( |
| 2079 | C: llvm::ConstantInt::get(offsetTy, tag.getOffset()))); |
| 2080 | if (tag.getConstant()) |
| 2081 | operands.push_back( |
| 2082 | Elt: llvm::ConstantAsMetadata::get(C: llvm::ConstantInt::get(Ty: offsetTy, V: 1))); |
| 2083 | |
| 2084 | tbaaMetadataMapping.insert({tag, llvm::MDNode::get(Context&: ctx, MDs: operands)}); |
| 2085 | }); |
| 2086 | |
| 2087 | mlirModule->walk(callback: [&](AliasAnalysisOpInterface analysisOpInterface) { |
| 2088 | if (auto attr = analysisOpInterface.getTBAATagsOrNull()) |
| 2089 | walker.walk(attr); |
| 2090 | }); |
| 2091 | |
| 2092 | return success(); |
| 2093 | } |
| 2094 | |
| 2095 | LogicalResult ModuleTranslation::createIdentMetadata() { |
| 2096 | if (auto attr = mlirModule->getAttrOfType<StringAttr>( |
| 2097 | LLVMDialect::getIdentAttrName())) { |
| 2098 | StringRef ident = attr; |
| 2099 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
| 2100 | llvm::NamedMDNode *namedMd = |
| 2101 | llvmModule->getOrInsertNamedMetadata(LLVMDialect::getIdentAttrName()); |
| 2102 | llvm::MDNode *md = llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: ident)); |
| 2103 | namedMd->addOperand(M: md); |
| 2104 | } |
| 2105 | |
| 2106 | return success(); |
| 2107 | } |
| 2108 | |
| 2109 | LogicalResult ModuleTranslation::createCommandlineMetadata() { |
| 2110 | if (auto attr = mlirModule->getAttrOfType<StringAttr>( |
| 2111 | LLVMDialect::getCommandlineAttrName())) { |
| 2112 | StringRef cmdLine = attr; |
| 2113 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
| 2114 | llvm::NamedMDNode *nmd = llvmModule->getOrInsertNamedMetadata( |
| 2115 | LLVMDialect::getCommandlineAttrName()); |
| 2116 | llvm::MDNode *md = |
| 2117 | llvm::MDNode::get(Context&: ctx, MDs: llvm::MDString::get(Context&: ctx, Str: cmdLine)); |
| 2118 | nmd->addOperand(M: md); |
| 2119 | } |
| 2120 | |
| 2121 | return success(); |
| 2122 | } |
| 2123 | |
| 2124 | LogicalResult ModuleTranslation::createDependentLibrariesMetadata() { |
| 2125 | if (auto dependentLibrariesAttr = mlirModule->getDiscardableAttr( |
| 2126 | LLVM::LLVMDialect::getDependentLibrariesAttrName())) { |
| 2127 | auto *nmd = |
| 2128 | llvmModule->getOrInsertNamedMetadata(Name: "llvm.dependent-libraries" ); |
| 2129 | llvm::LLVMContext &ctx = llvmModule->getContext(); |
| 2130 | for (auto libAttr : |
| 2131 | cast<ArrayAttr>(dependentLibrariesAttr).getAsRange<StringAttr>()) { |
| 2132 | auto *md = |
| 2133 | llvm::MDNode::get(ctx, llvm::MDString::get(ctx, libAttr.getValue())); |
| 2134 | nmd->addOperand(md); |
| 2135 | } |
| 2136 | } |
| 2137 | return success(); |
| 2138 | } |
| 2139 | |
| 2140 | void ModuleTranslation::setLoopMetadata(Operation *op, |
| 2141 | llvm::Instruction *inst) { |
| 2142 | LoopAnnotationAttr attr = |
| 2143 | TypeSwitch<Operation *, LoopAnnotationAttr>(op) |
| 2144 | .Case<LLVM::BrOp, LLVM::CondBrOp>( |
| 2145 | [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); }); |
| 2146 | if (!attr) |
| 2147 | return; |
| 2148 | llvm::MDNode *loopMD = |
| 2149 | loopAnnotationTranslation->translateLoopAnnotation(attr, op); |
| 2150 | inst->setMetadata(KindID: llvm::LLVMContext::MD_loop, Node: loopMD); |
| 2151 | } |
| 2152 | |
| 2153 | void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) { |
| 2154 | auto iface = cast<DisjointFlagInterface>(op); |
| 2155 | // We do a dyn_cast here in case the value got folded into a constant. |
| 2156 | if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(Val: value)) |
| 2157 | disjointInst->setIsDisjoint(iface.getIsDisjoint()); |
| 2158 | } |
| 2159 | |
| 2160 | llvm::Type *ModuleTranslation::convertType(Type type) { |
| 2161 | return typeTranslator.translateType(type); |
| 2162 | } |
| 2163 | |
| 2164 | /// A helper to look up remapped operands in the value remapping table. |
| 2165 | SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) { |
| 2166 | SmallVector<llvm::Value *> remapped; |
| 2167 | remapped.reserve(N: values.size()); |
| 2168 | for (Value v : values) |
| 2169 | remapped.push_back(Elt: lookupValue(value: v)); |
| 2170 | return remapped; |
| 2171 | } |
| 2172 | |
| 2173 | llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { |
| 2174 | if (!ompBuilder) { |
| 2175 | ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(args&: *llvmModule); |
| 2176 | ompBuilder->initialize(); |
| 2177 | |
| 2178 | // Flags represented as top-level OpenMP dialect attributes are set in |
| 2179 | // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set |
| 2180 | // the default configuration. |
| 2181 | ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig( |
| 2182 | /* IsTargetDevice = */ false, /* IsGPU = */ false, |
| 2183 | /* OpenMPOffloadMandatory = */ false, |
| 2184 | /* HasRequiresReverseOffload = */ false, |
| 2185 | /* HasRequiresUnifiedAddress = */ false, |
| 2186 | /* HasRequiresUnifiedSharedMemory = */ false, |
| 2187 | /* HasRequiresDynamicAllocators = */ false)); |
| 2188 | } |
| 2189 | return ompBuilder.get(); |
| 2190 | } |
| 2191 | |
| 2192 | llvm::DILocation *ModuleTranslation::translateLoc(Location loc, |
| 2193 | llvm::DILocalScope *scope) { |
| 2194 | return debugTranslation->translateLoc(loc, scope); |
| 2195 | } |
| 2196 | |
| 2197 | llvm::DIExpression * |
| 2198 | ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) { |
| 2199 | return debugTranslation->translateExpression(attr); |
| 2200 | } |
| 2201 | |
| 2202 | llvm::DIGlobalVariableExpression * |
| 2203 | ModuleTranslation::translateGlobalVariableExpression( |
| 2204 | LLVM::DIGlobalVariableExpressionAttr attr) { |
| 2205 | return debugTranslation->translateGlobalVariableExpression(attr); |
| 2206 | } |
| 2207 | |
| 2208 | llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) { |
| 2209 | return debugTranslation->translate(attr); |
| 2210 | } |
| 2211 | |
| 2212 | llvm::RoundingMode |
| 2213 | ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) { |
| 2214 | return convertRoundingModeToLLVM(rounding); |
| 2215 | } |
| 2216 | |
| 2217 | llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior( |
| 2218 | LLVM::FPExceptionBehavior exceptionBehavior) { |
| 2219 | return convertFPExceptionBehaviorToLLVM(exceptionBehavior); |
| 2220 | } |
| 2221 | |
| 2222 | llvm::NamedMDNode * |
| 2223 | ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) { |
| 2224 | return llvmModule->getOrInsertNamedMetadata(Name: name); |
| 2225 | } |
| 2226 | |
| 2227 | void ModuleTranslation::StackFrame::anchor() {} |
| 2228 | |
| 2229 | static std::unique_ptr<llvm::Module> |
| 2230 | prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, |
| 2231 | StringRef name) { |
| 2232 | m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>(); |
| 2233 | auto llvmModule = std::make_unique<llvm::Module>(args&: name, args&: llvmContext); |
| 2234 | // ModuleTranslation can currently only construct modules in the old debug |
| 2235 | // info format, so set the flag accordingly. |
| 2236 | llvmModule->setNewDbgInfoFormatFlag(false); |
| 2237 | if (auto dataLayoutAttr = |
| 2238 | m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) { |
| 2239 | llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue()); |
| 2240 | } else { |
| 2241 | FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout("" )); |
| 2242 | if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) { |
| 2243 | if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) { |
| 2244 | llvmDataLayout = |
| 2245 | translateDataLayout(spec, DataLayout(iface), m->getLoc()); |
| 2246 | } |
| 2247 | } else if (auto mod = dyn_cast<ModuleOp>(m)) { |
| 2248 | if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) { |
| 2249 | llvmDataLayout = |
| 2250 | translateDataLayout(spec, DataLayout(mod), m->getLoc()); |
| 2251 | } |
| 2252 | } |
| 2253 | if (failed(Result: llvmDataLayout)) |
| 2254 | return nullptr; |
| 2255 | llvmModule->setDataLayout(*llvmDataLayout); |
| 2256 | } |
| 2257 | if (auto targetTripleAttr = |
| 2258 | m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) |
| 2259 | llvmModule->setTargetTriple( |
| 2260 | llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue())); |
| 2261 | |
| 2262 | return llvmModule; |
| 2263 | } |
| 2264 | |
| 2265 | std::unique_ptr<llvm::Module> |
| 2266 | mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, |
| 2267 | StringRef name, bool disableVerification) { |
| 2268 | if (!satisfiesLLVMModule(op: module)) { |
| 2269 | module->emitOpError(message: "can not be translated to an LLVMIR module" ); |
| 2270 | return nullptr; |
| 2271 | } |
| 2272 | |
| 2273 | std::unique_ptr<llvm::Module> llvmModule = |
| 2274 | prepareLLVMModule(m: module, llvmContext, name); |
| 2275 | if (!llvmModule) |
| 2276 | return nullptr; |
| 2277 | |
| 2278 | LLVM::ensureDistinctSuccessors(op: module); |
| 2279 | LLVM::legalizeDIExpressionsRecursively(op: module); |
| 2280 | |
| 2281 | ModuleTranslation translator(module, std::move(llvmModule)); |
| 2282 | llvm::IRBuilder<llvm::TargetFolder> llvmBuilder( |
| 2283 | llvmContext, |
| 2284 | llvm::TargetFolder(translator.getLLVMModule()->getDataLayout())); |
| 2285 | |
| 2286 | // Convert module before functions and operations inside, so dialect |
| 2287 | // attributes can be used to change dialect-specific global configurations via |
| 2288 | // `amendOperation()`. These configurations can then influence the translation |
| 2289 | // of operations afterwards. |
| 2290 | if (failed(Result: translator.convertOperation(op&: *module, builder&: llvmBuilder))) |
| 2291 | return nullptr; |
| 2292 | |
| 2293 | if (failed(Result: translator.convertComdats())) |
| 2294 | return nullptr; |
| 2295 | if (failed(Result: translator.convertFunctionSignatures())) |
| 2296 | return nullptr; |
| 2297 | if (failed(Result: translator.convertGlobalsAndAliases())) |
| 2298 | return nullptr; |
| 2299 | if (failed(Result: translator.createTBAAMetadata())) |
| 2300 | return nullptr; |
| 2301 | if (failed(Result: translator.createIdentMetadata())) |
| 2302 | return nullptr; |
| 2303 | if (failed(Result: translator.createCommandlineMetadata())) |
| 2304 | return nullptr; |
| 2305 | if (failed(Result: translator.createDependentLibrariesMetadata())) |
| 2306 | return nullptr; |
| 2307 | |
| 2308 | // Convert other top-level operations if possible. |
| 2309 | for (Operation &o : getModuleBody(module).getOperations()) { |
| 2310 | if (!isa<LLVM::LLVMFuncOp, LLVM::AliasOp, LLVM::GlobalOp, |
| 2311 | LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) && |
| 2312 | !o.hasTrait<OpTrait::IsTerminator>() && |
| 2313 | failed(translator.convertOperation(o, llvmBuilder))) { |
| 2314 | return nullptr; |
| 2315 | } |
| 2316 | } |
| 2317 | |
| 2318 | // Operations in function bodies with symbolic references must be converted |
| 2319 | // after the top-level operations they refer to are declared, so we do it |
| 2320 | // last. |
| 2321 | if (failed(Result: translator.convertFunctions())) |
| 2322 | return nullptr; |
| 2323 | |
| 2324 | // Now that all MLIR blocks are resolved into LLVM ones, patch block address |
| 2325 | // constants to point to the correct blocks. |
| 2326 | if (failed(Result: translator.convertUnresolvedBlockAddress())) |
| 2327 | return nullptr; |
| 2328 | |
| 2329 | // Once we've finished constructing elements in the module, we should convert |
| 2330 | // it to use the debug info format desired by LLVM. |
| 2331 | // See https://llvm.org/docs/RemoveDIsDebugInfo.html |
| 2332 | translator.llvmModule->setIsNewDbgInfoFormat(true); |
| 2333 | |
| 2334 | // Add the necessary debug info module flags, if they were not encoded in MLIR |
| 2335 | // beforehand. |
| 2336 | translator.debugTranslation->addModuleFlagsIfNotPresent(); |
| 2337 | |
| 2338 | if (!disableVerification && |
| 2339 | llvm::verifyModule(M: *translator.llvmModule, OS: &llvm::errs())) |
| 2340 | return nullptr; |
| 2341 | |
| 2342 | return std::move(translator.llvmModule); |
| 2343 | } |
| 2344 | |