1//===- ModuleImport.cpp - LLVM to MLIR conversion ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the import of an LLVM IR module into an LLVM dialect
10// module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Target/LLVMIR/ModuleImport.h"
15#include "mlir/Target/LLVMIR/Import.h"
16
17#include "AttrKindDetail.h"
18#include "DataLayoutImporter.h"
19#include "DebugImporter.h"
20#include "LoopAnnotationImporter.h"
21
22#include "mlir/Dialect/DLTI/DLTI.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/Interfaces/DataLayoutInterfaces.h"
27#include "mlir/Tools/mlir-translate/Translation.h"
28
29#include "llvm/ADT/DepthFirstIterator.h"
30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/ADT/ScopeExit.h"
32#include "llvm/ADT/StringSet.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/IR/Comdat.h"
35#include "llvm/IR/Constants.h"
36#include "llvm/IR/InlineAsm.h"
37#include "llvm/IR/InstIterator.h"
38#include "llvm/IR/Instructions.h"
39#include "llvm/IR/IntrinsicInst.h"
40#include "llvm/IR/Metadata.h"
41#include "llvm/IR/Operator.h"
42#include "llvm/Support/ModRef.h"
43
44using namespace mlir;
45using namespace mlir::LLVM;
46using namespace mlir::LLVM::detail;
47
48#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
49
50// Utility to print an LLVM value as a string for passing to emitError().
51// FIXME: Diagnostic should be able to natively handle types that have
52// operator << (raw_ostream&) defined.
53static std::string diag(const llvm::Value &value) {
54 std::string str;
55 llvm::raw_string_ostream os(str);
56 os << value;
57 return os.str();
58}
59
60// Utility to print an LLVM metadata node as a string for passing
61// to emitError(). The module argument is needed to print the nodes
62// canonically numbered.
63static std::string diagMD(const llvm::Metadata *node,
64 const llvm::Module *module) {
65 std::string str;
66 llvm::raw_string_ostream os(str);
67 node->print(OS&: os, M: module, /*IsForDebug=*/true);
68 return os.str();
69}
70
71/// Returns the name of the global_ctors global variables.
72static constexpr StringRef getGlobalCtorsVarName() {
73 return "llvm.global_ctors";
74}
75
76/// Returns the name of the global_dtors global variables.
77static constexpr StringRef getGlobalDtorsVarName() {
78 return "llvm.global_dtors";
79}
80
81/// Returns the symbol name for the module-level comdat operation. It must not
82/// conflict with the user namespace.
83static constexpr StringRef getGlobalComdatOpName() {
84 return "__llvm_global_comdat";
85}
86
87/// Converts the sync scope identifier of `inst` to the string representation
88/// necessary to build an atomic LLVM dialect operation. Returns the empty
89/// string if the operation has either no sync scope or the default system-level
90/// sync scope attached. The atomic operations only set their sync scope
91/// attribute if they have a non-default sync scope attached.
92static StringRef getLLVMSyncScope(llvm::Instruction *inst) {
93 std::optional<llvm::SyncScope::ID> syncScopeID =
94 llvm::getAtomicSyncScopeID(I: inst);
95 if (!syncScopeID)
96 return "";
97
98 // Search the sync scope name for the given identifier. The default
99 // system-level sync scope thereby maps to the empty string.
100 SmallVector<StringRef> syncScopeName;
101 llvm::LLVMContext &llvmContext = inst->getContext();
102 llvmContext.getSyncScopeNames(SSNs&: syncScopeName);
103 auto *it = llvm::find_if(Range&: syncScopeName, P: [&](StringRef name) {
104 return *syncScopeID == llvmContext.getOrInsertSyncScopeID(SSN: name);
105 });
106 if (it != syncScopeName.end())
107 return *it;
108 llvm_unreachable("incorrect sync scope identifier");
109}
110
111/// Converts an array of unsigned indices to a signed integer position array.
112static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
113 SmallVector<int64_t> position;
114 llvm::append_range(C&: position, R&: indices);
115 return position;
116}
117
118/// Converts the LLVM instructions that have a generated MLIR builder. Using a
119/// static implementation method called from the module import ensures the
120/// builders have to use the `moduleImport` argument and cannot directly call
121/// import methods. As a result, both the intrinsic and the instruction MLIR
122/// builders have to use the `moduleImport` argument and none of them has direct
123/// access to the private module import methods.
124static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
125 llvm::Instruction *inst,
126 ModuleImport &moduleImport,
127 LLVMImportInterface &iface) {
128 // Copy the operands to an LLVM operands array reference for conversion.
129 SmallVector<llvm::Value *> operands(inst->operands());
130 ArrayRef<llvm::Value *> llvmOperands(operands);
131
132 // Convert all instructions that provide an MLIR builder.
133 if (iface.isConvertibleInstruction(id: inst->getOpcode()))
134 return iface.convertInstruction(builder&: odsBuilder, inst, llvmOperands,
135 moduleImport);
136 // TODO: Implement the `convertInstruction` hooks in the
137 // `LLVMDialectLLVMIRImportInterface` and move the following include there.
138#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
139 return failure();
140}
141
142/// Get a topologically sorted list of blocks for the given basic blocks.
143static SetVector<llvm::BasicBlock *>
144getTopologicallySortedBlocks(ArrayRef<llvm::BasicBlock *> basicBlocks) {
145 SetVector<llvm::BasicBlock *> blocks;
146 for (llvm::BasicBlock *basicBlock : basicBlocks) {
147 if (!blocks.contains(basicBlock)) {
148 llvm::ReversePostOrderTraversal<llvm::BasicBlock *> traversal(basicBlock);
149 blocks.insert(traversal.begin(), traversal.end());
150 }
151 }
152 assert(blocks.size() == basicBlocks.size() && "some blocks are not sorted");
153 return blocks;
154}
155
156ModuleImport::ModuleImport(ModuleOp mlirModule,
157 std::unique_ptr<llvm::Module> llvmModule,
158 bool emitExpensiveWarnings,
159 bool importEmptyDICompositeTypes)
160 : builder(mlirModule->getContext()), context(mlirModule->getContext()),
161 mlirModule(mlirModule), llvmModule(std::move(llvmModule)),
162 iface(mlirModule->getContext()),
163 typeTranslator(*mlirModule->getContext()),
164 debugImporter(std::make_unique<DebugImporter>(
165 mlirModule, importEmptyDICompositeTypes)),
166 loopAnnotationImporter(
167 std::make_unique<LoopAnnotationImporter>(args&: *this, args&: builder)),
168 emitExpensiveWarnings(emitExpensiveWarnings) {
169 builder.setInsertionPointToStart(mlirModule.getBody());
170}
171
172ComdatOp ModuleImport::getGlobalComdatOp() {
173 if (globalComdatOp)
174 return globalComdatOp;
175
176 OpBuilder::InsertionGuard guard(builder);
177 builder.setInsertionPointToEnd(mlirModule.getBody());
178 globalComdatOp =
179 builder.create<ComdatOp>(mlirModule.getLoc(), getGlobalComdatOpName());
180 globalInsertionOp = globalComdatOp;
181 return globalComdatOp;
182}
183
184LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
185 Location loc = mlirModule.getLoc();
186
187 // If `node` is a valid TBAA root node, then return its optional identity
188 // string, otherwise return failure.
189 auto getIdentityIfRootNode =
190 [&](const llvm::MDNode *node) -> FailureOr<std::optional<StringRef>> {
191 // Root node, e.g.:
192 // !0 = !{!"Simple C/C++ TBAA"}
193 // !1 = !{}
194 if (node->getNumOperands() > 1)
195 return failure();
196 // If the operand is MDString, then assume that this is a root node.
197 if (node->getNumOperands() == 1)
198 if (const auto *op0 = dyn_cast<const llvm::MDString>(Val: node->getOperand(I: 0)))
199 return std::optional<StringRef>{op0->getString()};
200 return std::optional<StringRef>{};
201 };
202
203 // If `node` looks like a TBAA type descriptor metadata,
204 // then return true, if it is a valid node, and false otherwise.
205 // If it does not look like a TBAA type descriptor metadata, then
206 // return std::nullopt.
207 // If `identity` and `memberTypes/Offsets` are non-null, then they will
208 // contain the converted metadata operands for a valid TBAA node (i.e. when
209 // true is returned).
210 auto isTypeDescriptorNode = [&](const llvm::MDNode *node,
211 StringRef *identity = nullptr,
212 SmallVectorImpl<TBAAMemberAttr> *members =
213 nullptr) -> std::optional<bool> {
214 unsigned numOperands = node->getNumOperands();
215 // Type descriptor, e.g.:
216 // !1 = !{!"int", !0, /*optional*/i64 0} /* scalar int type */
217 // !2 = !{!"agg_t", !1, i64 0} /* struct agg_t { int x; } */
218 if (numOperands < 2)
219 return std::nullopt;
220
221 // TODO: support "new" format (D41501) for type descriptors,
222 // where the first operand is an MDNode.
223 const auto *identityNode =
224 dyn_cast<const llvm::MDString>(Val: node->getOperand(I: 0));
225 if (!identityNode)
226 return std::nullopt;
227
228 // This should be a type descriptor node.
229 if (identity)
230 *identity = identityNode->getString();
231
232 for (unsigned pairNum = 0, e = numOperands / 2; pairNum < e; ++pairNum) {
233 const auto *memberNode =
234 dyn_cast<const llvm::MDNode>(Val: node->getOperand(I: 2 * pairNum + 1));
235 if (!memberNode) {
236 emitError(loc) << "operand '" << 2 * pairNum + 1 << "' must be MDNode: "
237 << diagMD(node, module: llvmModule.get());
238 return false;
239 }
240 int64_t offset = 0;
241 if (2 * pairNum + 2 >= numOperands) {
242 // Allow for optional 0 offset in 2-operand nodes.
243 if (numOperands != 2) {
244 emitError(loc) << "missing member offset: "
245 << diagMD(node, module: llvmModule.get());
246 return false;
247 }
248 } else {
249 auto *offsetCI = llvm::mdconst::dyn_extract<llvm::ConstantInt>(
250 MD: node->getOperand(I: 2 * pairNum + 2));
251 if (!offsetCI) {
252 emitError(loc) << "operand '" << 2 * pairNum + 2
253 << "' must be ConstantInt: "
254 << diagMD(node, module: llvmModule.get());
255 return false;
256 }
257 offset = offsetCI->getZExtValue();
258 }
259
260 if (members)
261 members->push_back(TBAAMemberAttr::get(
262 cast<TBAANodeAttr>(tbaaMapping.lookup(memberNode)), offset));
263 }
264
265 return true;
266 };
267
268 // If `node` looks like a TBAA access tag metadata,
269 // then return true, if it is a valid node, and false otherwise.
270 // If it does not look like a TBAA access tag metadata, then
271 // return std::nullopt.
272 // If the other arguments are non-null, then they will contain
273 // the converted metadata operands for a valid TBAA node (i.e. when true is
274 // returned).
275 auto isTagNode = [&](const llvm::MDNode *node,
276 TBAATypeDescriptorAttr *baseAttr = nullptr,
277 TBAATypeDescriptorAttr *accessAttr = nullptr,
278 int64_t *offset = nullptr,
279 bool *isConstant = nullptr) -> std::optional<bool> {
280 // Access tag, e.g.:
281 // !3 = !{!1, !1, i64 0} /* scalar int access */
282 // !4 = !{!2, !1, i64 0} /* agg_t::x access */
283 //
284 // Optional 4th argument is ConstantInt 0/1 identifying whether
285 // the location being accessed is "constant" (see for details:
286 // https://llvm.org/docs/LangRef.html#representation).
287 unsigned numOperands = node->getNumOperands();
288 if (numOperands != 3 && numOperands != 4)
289 return std::nullopt;
290 const auto *baseMD = dyn_cast<const llvm::MDNode>(Val: node->getOperand(I: 0));
291 const auto *accessMD = dyn_cast<const llvm::MDNode>(Val: node->getOperand(I: 1));
292 auto *offsetCI =
293 llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: node->getOperand(I: 2));
294 if (!baseMD || !accessMD || !offsetCI)
295 return std::nullopt;
296 // TODO: support "new" TBAA format, if needed (see D41501).
297 // In the "old" format the first operand of the access type
298 // metadata is MDString. We have to distinguish the formats,
299 // because access tags have the same structure, but different
300 // meaning for the operands.
301 if (accessMD->getNumOperands() < 1 ||
302 !isa<llvm::MDString>(Val: accessMD->getOperand(I: 0)))
303 return std::nullopt;
304 bool isConst = false;
305 if (numOperands == 4) {
306 auto *isConstantCI =
307 llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: node->getOperand(I: 3));
308 if (!isConstantCI) {
309 emitError(loc) << "operand '3' must be ConstantInt: "
310 << diagMD(node, module: llvmModule.get());
311 return false;
312 }
313 isConst = isConstantCI->getValue()[0];
314 }
315 if (baseAttr)
316 *baseAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(baseMD));
317 if (accessAttr)
318 *accessAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(accessMD));
319 if (offset)
320 *offset = offsetCI->getZExtValue();
321 if (isConstant)
322 *isConstant = isConst;
323 return true;
324 };
325
326 // Do a post-order walk over the TBAA Graph. Since a correct TBAA Graph is a
327 // DAG, a post-order walk guarantees that we convert any metadata node we
328 // depend on, prior to converting the current node.
329 DenseSet<const llvm::MDNode *> seen;
330 SmallVector<const llvm::MDNode *> workList;
331 workList.push_back(Elt: node);
332 while (!workList.empty()) {
333 const llvm::MDNode *current = workList.back();
334 if (tbaaMapping.contains(Val: current)) {
335 // Already converted. Just pop from the worklist.
336 workList.pop_back();
337 continue;
338 }
339
340 // If any child of this node is not yet converted, don't pop the current
341 // node from the worklist but push the not-yet-converted children in the
342 // front of the worklist.
343 bool anyChildNotConverted = false;
344 for (const llvm::MDOperand &operand : current->operands())
345 if (auto *childNode = dyn_cast_or_null<const llvm::MDNode>(Val: operand.get()))
346 if (!tbaaMapping.contains(Val: childNode)) {
347 workList.push_back(Elt: childNode);
348 anyChildNotConverted = true;
349 }
350
351 if (anyChildNotConverted) {
352 // If this is the second time we failed to convert an element in the
353 // worklist it must be because a child is dependent on it being converted
354 // and we have a cycle in the graph. Cycles are not allowed in TBAA
355 // graphs.
356 if (!seen.insert(V: current).second)
357 return emitError(loc) << "has cycle in TBAA graph: "
358 << diagMD(node: current, module: llvmModule.get());
359
360 continue;
361 }
362
363 // Otherwise simply import the current node.
364 workList.pop_back();
365
366 FailureOr<std::optional<StringRef>> rootNodeIdentity =
367 getIdentityIfRootNode(current);
368 if (succeeded(result: rootNodeIdentity)) {
369 StringAttr stringAttr = *rootNodeIdentity
370 ? builder.getStringAttr(**rootNodeIdentity)
371 : nullptr;
372 // The root nodes do not have operands, so we can create
373 // the TBAARootAttr on the first walk.
374 tbaaMapping.insert({current, builder.getAttr<TBAARootAttr>(stringAttr)});
375 continue;
376 }
377
378 StringRef identity;
379 SmallVector<TBAAMemberAttr> members;
380 if (std::optional<bool> isValid =
381 isTypeDescriptorNode(current, &identity, &members)) {
382 assert(isValid.value() && "type descriptor node must be valid");
383
384 tbaaMapping.insert({current, builder.getAttr<TBAATypeDescriptorAttr>(
385 identity, members)});
386 continue;
387 }
388
389 TBAATypeDescriptorAttr baseAttr, accessAttr;
390 int64_t offset;
391 bool isConstant;
392 if (std::optional<bool> isValid =
393 isTagNode(current, &baseAttr, &accessAttr, &offset, &isConstant)) {
394 assert(isValid.value() && "access tag node must be valid");
395 tbaaMapping.insert(
396 {current, builder.getAttr<TBAATagAttr>(baseAttr, accessAttr, offset,
397 isConstant)});
398 continue;
399 }
400
401 return emitError(loc) << "unsupported TBAA node format: "
402 << diagMD(node: current, module: llvmModule.get());
403 }
404 return success();
405}
406
407LogicalResult
408ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
409 Location loc = mlirModule.getLoc();
410 if (failed(result: loopAnnotationImporter->translateAccessGroup(node, loc)))
411 return emitError(loc) << "unsupported access group node: "
412 << diagMD(node, module: llvmModule.get());
413 return success();
414}
415
416LogicalResult
417ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) {
418 Location loc = mlirModule.getLoc();
419 // Helper that verifies the node has a self reference operand.
420 auto verifySelfRef = [](const llvm::MDNode *node) {
421 return node->getNumOperands() != 0 &&
422 node == dyn_cast<llvm::MDNode>(Val: node->getOperand(I: 0));
423 };
424 // Helper that verifies the given operand is a string or does not exist.
425 auto verifyDescription = [](const llvm::MDNode *node, unsigned idx) {
426 return idx >= node->getNumOperands() ||
427 isa<llvm::MDString>(Val: node->getOperand(I: idx));
428 };
429 // Helper that creates an alias scope domain attribute.
430 auto createAliasScopeDomainOp = [&](const llvm::MDNode *aliasDomain) {
431 StringAttr description = nullptr;
432 if (aliasDomain->getNumOperands() >= 2)
433 if (auto *operand = dyn_cast<llvm::MDString>(Val: aliasDomain->getOperand(I: 1)))
434 description = builder.getStringAttr(operand->getString());
435 return builder.getAttr<AliasScopeDomainAttr>(
436 DistinctAttr::create(builder.getUnitAttr()), description);
437 };
438
439 // Collect the alias scopes and domains to translate them.
440 for (const llvm::MDOperand &operand : node->operands()) {
441 if (const auto *scope = dyn_cast<llvm::MDNode>(Val: operand)) {
442 llvm::AliasScopeNode aliasScope(scope);
443 const llvm::MDNode *domain = aliasScope.getDomain();
444
445 // Verify the scope node points to valid scope metadata which includes
446 // verifying its domain. Perform the verification before looking it up in
447 // the alias scope mapping since it could have been inserted as a domain
448 // node before.
449 if (!verifySelfRef(scope) || !domain || !verifyDescription(scope, 2))
450 return emitError(loc) << "unsupported alias scope node: "
451 << diagMD(node: scope, module: llvmModule.get());
452 if (!verifySelfRef(domain) || !verifyDescription(domain, 1))
453 return emitError(loc) << "unsupported alias domain node: "
454 << diagMD(node: domain, module: llvmModule.get());
455
456 if (aliasScopeMapping.contains(Val: scope))
457 continue;
458
459 // Convert the domain metadata node if it has not been translated before.
460 auto it = aliasScopeMapping.find(Val: aliasScope.getDomain());
461 if (it == aliasScopeMapping.end()) {
462 auto aliasScopeDomainOp = createAliasScopeDomainOp(domain);
463 it = aliasScopeMapping.try_emplace(domain, aliasScopeDomainOp).first;
464 }
465
466 // Convert the scope metadata node if it has not been converted before.
467 StringAttr description = nullptr;
468 if (!aliasScope.getName().empty())
469 description = builder.getStringAttr(aliasScope.getName());
470 auto aliasScopeOp = builder.getAttr<AliasScopeAttr>(
471 DistinctAttr::create(builder.getUnitAttr()),
472 cast<AliasScopeDomainAttr>(it->second), description);
473 aliasScopeMapping.try_emplace(aliasScope.getNode(), aliasScopeOp);
474 }
475 }
476 return success();
477}
478
479FailureOr<SmallVector<AliasScopeAttr>>
480ModuleImport::lookupAliasScopeAttrs(const llvm::MDNode *node) const {
481 SmallVector<AliasScopeAttr> aliasScopes;
482 aliasScopes.reserve(node->getNumOperands());
483 for (const llvm::MDOperand &operand : node->operands()) {
484 auto *node = cast<llvm::MDNode>(Val: operand.get());
485 aliasScopes.push_back(
486 dyn_cast_or_null<AliasScopeAttr>(aliasScopeMapping.lookup(node)));
487 }
488 // Return failure if one of the alias scope lookups failed.
489 if (llvm::is_contained(aliasScopes, nullptr))
490 return failure();
491 return aliasScopes;
492}
493
494void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
495 debugIntrinsics.insert(X: intrinsic);
496}
497
498LogicalResult ModuleImport::convertLinkerOptionsMetadata() {
499 for (const llvm::NamedMDNode &named : llvmModule->named_metadata()) {
500 if (named.getName() != "llvm.linker.options")
501 continue;
502 // llvm.linker.options operands are lists of strings.
503 for (const llvm::MDNode *md : named.operands()) {
504 SmallVector<StringRef> options;
505 options.reserve(N: md->getNumOperands());
506 for (const llvm::MDOperand &option : md->operands())
507 options.push_back(Elt: cast<llvm::MDString>(Val: option)->getString());
508 builder.create<LLVM::LinkerOptionsOp>(mlirModule.getLoc(),
509 builder.getStrArrayAttr(options));
510 }
511 }
512 return success();
513}
514
515LogicalResult ModuleImport::convertMetadata() {
516 OpBuilder::InsertionGuard guard(builder);
517 builder.setInsertionPointToEnd(mlirModule.getBody());
518 for (const llvm::Function &func : llvmModule->functions()) {
519 for (const llvm::Instruction &inst : llvm::instructions(F: func)) {
520 // Convert access group metadata nodes.
521 if (llvm::MDNode *node =
522 inst.getMetadata(KindID: llvm::LLVMContext::MD_access_group))
523 if (failed(result: processAccessGroupMetadata(node)))
524 return failure();
525
526 // Convert alias analysis metadata nodes.
527 llvm::AAMDNodes aliasAnalysisNodes = inst.getAAMetadata();
528 if (!aliasAnalysisNodes)
529 continue;
530 if (aliasAnalysisNodes.TBAA)
531 if (failed(result: processTBAAMetadata(node: aliasAnalysisNodes.TBAA)))
532 return failure();
533 if (aliasAnalysisNodes.Scope)
534 if (failed(result: processAliasScopeMetadata(node: aliasAnalysisNodes.Scope)))
535 return failure();
536 if (aliasAnalysisNodes.NoAlias)
537 if (failed(result: processAliasScopeMetadata(node: aliasAnalysisNodes.NoAlias)))
538 return failure();
539 }
540 }
541 if (failed(result: convertLinkerOptionsMetadata()))
542 return failure();
543 return success();
544}
545
546void ModuleImport::processComdat(const llvm::Comdat *comdat) {
547 if (comdatMapping.contains(Val: comdat))
548 return;
549
550 ComdatOp comdatOp = getGlobalComdatOp();
551 OpBuilder::InsertionGuard guard(builder);
552 builder.setInsertionPointToEnd(&comdatOp.getBody().back());
553 auto selectorOp = builder.create<ComdatSelectorOp>(
554 mlirModule.getLoc(), comdat->getName(),
555 convertComdatFromLLVM(comdat->getSelectionKind()));
556 auto symbolRef =
557 SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(),
558 FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()));
559 comdatMapping.try_emplace(comdat, symbolRef);
560}
561
562LogicalResult ModuleImport::convertComdats() {
563 for (llvm::GlobalVariable &globalVar : llvmModule->globals())
564 if (globalVar.hasComdat())
565 processComdat(comdat: globalVar.getComdat());
566 for (llvm::Function &func : llvmModule->functions())
567 if (func.hasComdat())
568 processComdat(comdat: func.getComdat());
569 return success();
570}
571
572LogicalResult ModuleImport::convertGlobals() {
573 for (llvm::GlobalVariable &globalVar : llvmModule->globals()) {
574 if (globalVar.getName() == getGlobalCtorsVarName() ||
575 globalVar.getName() == getGlobalDtorsVarName()) {
576 if (failed(result: convertGlobalCtorsAndDtors(globalVar: &globalVar))) {
577 return emitError(UnknownLoc::get(context))
578 << "unhandled global variable: " << diag(globalVar);
579 }
580 continue;
581 }
582 if (failed(result: convertGlobal(globalVar: &globalVar))) {
583 return emitError(UnknownLoc::get(context))
584 << "unhandled global variable: " << diag(globalVar);
585 }
586 }
587 return success();
588}
589
590LogicalResult ModuleImport::convertDataLayout() {
591 Location loc = mlirModule.getLoc();
592 DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout());
593 if (!dataLayoutImporter.getDataLayout())
594 return emitError(loc, message: "cannot translate data layout: ")
595 << dataLayoutImporter.getLastToken();
596
597 for (StringRef token : dataLayoutImporter.getUnhandledTokens())
598 emitWarning(loc, message: "unhandled data layout token: ") << token;
599
600 mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName,
601 dataLayoutImporter.getDataLayout());
602 return success();
603}
604
605LogicalResult ModuleImport::convertFunctions() {
606 for (llvm::Function &func : llvmModule->functions())
607 if (failed(result: processFunction(func: &func)))
608 return failure();
609 return success();
610}
611
612void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
613 Operation *op) {
614 SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
615 inst->getAllMetadataOtherThanDebugLoc(MDs&: allMetadata);
616 for (auto &[kind, node] : allMetadata) {
617 if (!iface.isConvertibleMetadata(kind))
618 continue;
619 if (failed(result: iface.setMetadataAttrs(builder, kind, node, op, moduleImport&: *this))) {
620 if (emitExpensiveWarnings) {
621 Location loc = debugImporter->translateLoc(loc: inst->getDebugLoc());
622 emitWarning(loc) << "unhandled metadata: "
623 << diagMD(node, module: llvmModule.get()) << " on "
624 << diag(value: *inst);
625 }
626 }
627 }
628}
629
630void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
631 Operation *op) const {
632 auto iface = cast<IntegerOverflowFlagsInterface>(op);
633
634 IntegerOverflowFlags value = {};
635 value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap());
636 value =
637 bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
638
639 iface.setOverflowFlags(value);
640}
641
642void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
643 Operation *op) const {
644 auto iface = cast<FastmathFlagsInterface>(op);
645
646 // Even if the imported operation implements the fastmath interface, the
647 // original instruction may not have fastmath flags set. Exit if an
648 // instruction, such as a non floating-point function call, does not have
649 // fastmath flags.
650 if (!isa<llvm::FPMathOperator>(Val: inst))
651 return;
652 llvm::FastMathFlags flags = inst->getFastMathFlags();
653
654 // Set the fastmath bits flag-by-flag.
655 FastmathFlags value = {};
656 value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs());
657 value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs());
658 value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros());
659 value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal());
660 value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract());
661 value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc());
662 value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc());
663 FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value);
664 iface->setAttr(iface.getFastmathAttrName(), attr);
665}
666
667/// Returns if `type` is a scalar integer or floating-point type.
668static bool isScalarType(Type type) {
669 return isa<IntegerType, FloatType>(Val: type);
670}
671
672/// Returns `type` if it is a builtin integer or floating-point vector type that
673/// can be used to create an attribute or nullptr otherwise. If provided,
674/// `arrayShape` is added to the shape of the vector to create an attribute that
675/// matches an array of vectors.
676static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
677 if (!LLVM::isCompatibleVectorType(type))
678 return {};
679
680 llvm::ElementCount numElements = LLVM::getVectorNumElements(type);
681 if (numElements.isScalable()) {
682 emitError(UnknownLoc::get(type.getContext()))
683 << "scalable vectors not supported";
684 return {};
685 }
686
687 // An LLVM dialect vector can only contain scalars.
688 Type elementType = LLVM::getVectorElementType(type);
689 if (!isScalarType(type: elementType))
690 return {};
691
692 SmallVector<int64_t> shape(arrayShape.begin(), arrayShape.end());
693 shape.push_back(Elt: numElements.getKnownMinValue());
694 return VectorType::get(shape, elementType);
695}
696
697Type ModuleImport::getBuiltinTypeForAttr(Type type) {
698 if (!type)
699 return {};
700
701 // Return builtin integer and floating-point types as is.
702 if (isScalarType(type))
703 return type;
704
705 // Return builtin vectors of integer and floating-point types as is.
706 if (Type vectorType = getVectorTypeForAttr(type))
707 return vectorType;
708
709 // Multi-dimensional array types are converted to tensors or vectors,
710 // depending on the innermost type being a scalar or a vector.
711 SmallVector<int64_t> arrayShape;
712 while (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
713 arrayShape.push_back(Elt: arrayType.getNumElements());
714 type = arrayType.getElementType();
715 }
716 if (isScalarType(type))
717 return RankedTensorType::get(arrayShape, type);
718 return getVectorTypeForAttr(type, arrayShape);
719}
720
721/// Returns an integer or float attribute for the provided scalar constant
722/// `constScalar` or nullptr if the conversion fails.
723static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
724 llvm::Constant *constScalar) {
725 MLIRContext *context = builder.getContext();
726
727 // Convert scalar intergers.
728 if (auto *constInt = dyn_cast<llvm::ConstantInt>(Val: constScalar)) {
729 return builder.getIntegerAttr(
730 IntegerType::get(context, constInt->getBitWidth()),
731 constInt->getValue());
732 }
733
734 // Convert scalar floats.
735 if (auto *constFloat = dyn_cast<llvm::ConstantFP>(Val: constScalar)) {
736 llvm::Type *type = constFloat->getType();
737 FloatType floatType =
738 type->isBFloatTy()
739 ? FloatType::getBF16(ctx: context)
740 : LLVM::detail::getFloatType(context, width: type->getScalarSizeInBits());
741 if (!floatType) {
742 emitError(UnknownLoc::get(builder.getContext()))
743 << "unexpected floating-point type";
744 return {};
745 }
746 return builder.getFloatAttr(floatType, constFloat->getValueAPF());
747 }
748 return {};
749}
750
751/// Returns an integer or float attribute array for the provided constant
752/// sequence `constSequence` or nullptr if the conversion fails.
753static SmallVector<Attribute>
754getSequenceConstantAsAttrs(OpBuilder &builder,
755 llvm::ConstantDataSequential *constSequence) {
756 SmallVector<Attribute> elementAttrs;
757 elementAttrs.reserve(N: constSequence->getNumElements());
758 for (auto idx : llvm::seq<int64_t>(Begin: 0, End: constSequence->getNumElements())) {
759 llvm::Constant *constElement = constSequence->getElementAsConstant(i: idx);
760 elementAttrs.push_back(getScalarConstantAsAttr(builder, constElement));
761 }
762 return elementAttrs;
763}
764
765Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
766 // Convert scalar constants.
767 if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant))
768 return scalarAttr;
769
770 // Convert function references.
771 if (auto *func = dyn_cast<llvm::Function>(constant))
772 return SymbolRefAttr::get(builder.getContext(), func->getName());
773
774 // Returns the static shape of the provided type if possible.
775 auto getConstantShape = [&](llvm::Type *type) {
776 return llvm::dyn_cast_if_present<ShapedType>(
777 getBuiltinTypeForAttr(convertType(type)));
778 };
779
780 // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
781 // integer or half/bfloat/float/double values.
782 if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(Val: constant)) {
783 if (constArray->isString())
784 return builder.getStringAttr(constArray->getAsString());
785 auto shape = getConstantShape(constArray->getType());
786 if (!shape)
787 return {};
788 // Convert splat constants to splat elements attributes.
789 auto *constVector = dyn_cast<llvm::ConstantDataVector>(Val: constant);
790 if (constVector && constVector->isSplat()) {
791 // A vector is guaranteed to have at least size one.
792 Attribute splatAttr = getScalarConstantAsAttr(
793 builder, constVector->getElementAsConstant(0));
794 return SplatElementsAttr::get(shape, splatAttr);
795 }
796 // Convert non-splat constants to dense elements attributes.
797 SmallVector<Attribute> elementAttrs =
798 getSequenceConstantAsAttrs(builder, constSequence: constArray);
799 return DenseElementsAttr::get(shape, elementAttrs);
800 }
801
802 // Convert multi-dimensional constant aggregates that store all kinds of
803 // integer and floating-point types.
804 if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(Val: constant)) {
805 auto shape = getConstantShape(constAggregate->getType());
806 if (!shape)
807 return {};
808 // Collect the aggregate elements in depths first order.
809 SmallVector<Attribute> elementAttrs;
810 SmallVector<llvm::Constant *> workList = {constAggregate};
811 while (!workList.empty()) {
812 llvm::Constant *current = workList.pop_back_val();
813 // Append any nested aggregates in reverse order to ensure the head
814 // element of the nested aggregates is at the back of the work list.
815 if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(Val: current)) {
816 for (auto idx :
817 reverse(C: llvm::seq<int64_t>(Begin: 0, End: constAggregate->getNumOperands())))
818 workList.push_back(Elt: constAggregate->getAggregateElement(Elt: idx));
819 continue;
820 }
821 // Append the elements of nested constant arrays or vectors that store
822 // 1/2/4/8-byte integer or half/bfloat/float/double values.
823 if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(Val: current)) {
824 SmallVector<Attribute> attrs =
825 getSequenceConstantAsAttrs(builder, constSequence: constArray);
826 elementAttrs.append(in_start: attrs.begin(), in_end: attrs.end());
827 continue;
828 }
829 // Append nested scalar constants that store all kinds of integer and
830 // floating-point types.
831 if (Attribute scalarAttr = getScalarConstantAsAttr(builder, current)) {
832 elementAttrs.push_back(Elt: scalarAttr);
833 continue;
834 }
835 // Bail if the aggregate contains a unsupported constant type such as a
836 // constant expression.
837 return {};
838 }
839 return DenseElementsAttr::get(shape, elementAttrs);
840 }
841
842 // Convert zero aggregates.
843 if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(Val: constant)) {
844 auto shape = llvm::dyn_cast_if_present<ShapedType>(
845 getBuiltinTypeForAttr(convertType(constZero->getType())));
846 if (!shape)
847 return {};
848 // Convert zero aggregates with a static shape to splat elements attributes.
849 Attribute splatAttr = builder.getZeroAttr(type: shape.getElementType());
850 assert(splatAttr && "expected non-null zero attribute for scalar types");
851 return SplatElementsAttr::get(shape, splatAttr);
852 }
853 return {};
854}
855
856LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
857 // Insert the global after the last one or at the start of the module.
858 OpBuilder::InsertionGuard guard(builder);
859 if (!globalInsertionOp)
860 builder.setInsertionPointToStart(mlirModule.getBody());
861 else
862 builder.setInsertionPointAfter(globalInsertionOp);
863
864 Attribute valueAttr;
865 if (globalVar->hasInitializer())
866 valueAttr = getConstantAsAttr(constant: globalVar->getInitializer());
867 Type type = convertType(type: globalVar->getValueType());
868
869 uint64_t alignment = 0;
870 llvm::MaybeAlign maybeAlign = globalVar->getAlign();
871 if (maybeAlign.has_value()) {
872 llvm::Align align = *maybeAlign;
873 alignment = align.value();
874 }
875
876 // Get the global expression associated with this global variable and convert
877 // it.
878 DIGlobalVariableExpressionAttr globalExpressionAttr;
879 SmallVector<llvm::DIGlobalVariableExpression *> globalExpressions;
880 globalVar->getDebugInfo(GVs&: globalExpressions);
881
882 // There should only be a single global expression.
883 if (!globalExpressions.empty())
884 globalExpressionAttr =
885 debugImporter->translateGlobalVariableExpression(globalExpressions[0]);
886
887 GlobalOp globalOp = builder.create<GlobalOp>(
888 mlirModule.getLoc(), type, globalVar->isConstant(),
889 convertLinkageFromLLVM(globalVar->getLinkage()), globalVar->getName(),
890 valueAttr, alignment, /*addr_space=*/globalVar->getAddressSpace(),
891 /*dso_local=*/globalVar->isDSOLocal(),
892 /*thread_local=*/globalVar->isThreadLocal(), /*comdat=*/SymbolRefAttr(),
893 /*attrs=*/ArrayRef<NamedAttribute>(), /*dbgExpr=*/globalExpressionAttr);
894 globalInsertionOp = globalOp;
895
896 if (globalVar->hasInitializer() && !valueAttr) {
897 clearRegionState();
898 Block *block = builder.createBlock(&globalOp.getInitializerRegion());
899 setConstantInsertionPointToStart(block);
900 FailureOr<Value> initializer =
901 convertConstantExpr(constant: globalVar->getInitializer());
902 if (failed(result: initializer))
903 return failure();
904 builder.create<ReturnOp>(globalOp.getLoc(), *initializer);
905 }
906 if (globalVar->hasAtLeastLocalUnnamedAddr()) {
907 globalOp.setUnnamedAddr(
908 convertUnnamedAddrFromLLVM(globalVar->getUnnamedAddr()));
909 }
910 if (globalVar->hasSection())
911 globalOp.setSection(globalVar->getSection());
912 globalOp.setVisibility_(
913 convertVisibilityFromLLVM(globalVar->getVisibility()));
914
915 if (globalVar->hasComdat())
916 globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat()));
917
918 return success();
919}
920
921LogicalResult
922ModuleImport::convertGlobalCtorsAndDtors(llvm::GlobalVariable *globalVar) {
923 if (!globalVar->hasInitializer() || !globalVar->hasAppendingLinkage())
924 return failure();
925 auto *initializer =
926 dyn_cast<llvm::ConstantArray>(Val: globalVar->getInitializer());
927 if (!initializer)
928 return failure();
929
930 SmallVector<Attribute> funcs;
931 SmallVector<int32_t> priorities;
932 for (llvm::Value *operand : initializer->operands()) {
933 auto *aggregate = dyn_cast<llvm::ConstantAggregate>(Val: operand);
934 if (!aggregate || aggregate->getNumOperands() != 3)
935 return failure();
936
937 auto *priority = dyn_cast<llvm::ConstantInt>(Val: aggregate->getOperand(i_nocapture: 0));
938 auto *func = dyn_cast<llvm::Function>(Val: aggregate->getOperand(i_nocapture: 1));
939 auto *data = dyn_cast<llvm::Constant>(Val: aggregate->getOperand(i_nocapture: 2));
940 if (!priority || !func || !data)
941 return failure();
942
943 // GlobalCtorsOps and GlobalDtorsOps do not support non-null data fields.
944 if (!data->isNullValue())
945 return failure();
946
947 funcs.push_back(FlatSymbolRefAttr::get(ctx: context, value: func->getName()));
948 priorities.push_back(Elt: priority->getValue().getZExtValue());
949 }
950
951 OpBuilder::InsertionGuard guard(builder);
952 if (!globalInsertionOp)
953 builder.setInsertionPointToStart(mlirModule.getBody());
954 else
955 builder.setInsertionPointAfter(globalInsertionOp);
956
957 if (globalVar->getName() == getGlobalCtorsVarName()) {
958 globalInsertionOp = builder.create<LLVM::GlobalCtorsOp>(
959 mlirModule.getLoc(), builder.getArrayAttr(funcs),
960 builder.getI32ArrayAttr(priorities));
961 return success();
962 }
963 globalInsertionOp = builder.create<LLVM::GlobalDtorsOp>(
964 mlirModule.getLoc(), builder.getArrayAttr(funcs),
965 builder.getI32ArrayAttr(priorities));
966 return success();
967}
968
969SetVector<llvm::Constant *>
970ModuleImport::getConstantsToConvert(llvm::Constant *constant) {
971 // Return the empty set if the constant has been translated before.
972 if (valueMapping.contains(Val: constant))
973 return {};
974
975 // Traverse the constants in post-order and stop the traversal if a constant
976 // already has a `valueMapping` from an earlier constant translation or if the
977 // constant is traversed a second time.
978 SetVector<llvm::Constant *> orderedSet;
979 SetVector<llvm::Constant *> workList;
980 DenseMap<llvm::Constant *, SmallVector<llvm::Constant *>> adjacencyLists;
981 workList.insert(X: constant);
982 while (!workList.empty()) {
983 llvm::Constant *current = workList.back();
984 // Collect all dependencies of the current constant and add them to the
985 // adjacency list if none has been computed before.
986 auto adjacencyIt = adjacencyLists.find(Val: current);
987 if (adjacencyIt == adjacencyLists.end()) {
988 adjacencyIt = adjacencyLists.try_emplace(Key: current).first;
989 // Add all constant operands to the adjacency list and skip any other
990 // values such as basic block addresses.
991 for (llvm::Value *operand : current->operands())
992 if (auto *constDependency = dyn_cast<llvm::Constant>(Val: operand))
993 adjacencyIt->getSecond().push_back(Elt: constDependency);
994 // Use the getElementValue method to add the dependencies of zero
995 // initialized aggregate constants since they do not take any operands.
996 if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(Val: current)) {
997 unsigned numElements = constAgg->getElementCount().getFixedValue();
998 for (unsigned i = 0, e = numElements; i != e; ++i)
999 adjacencyIt->getSecond().push_back(Elt: constAgg->getElementValue(Idx: i));
1000 }
1001 }
1002 // Add the current constant to the `orderedSet` of the traversed nodes if
1003 // all its dependencies have been traversed before. Additionally, remove the
1004 // constant from the `workList` and continue the traversal.
1005 if (adjacencyIt->getSecond().empty()) {
1006 orderedSet.insert(X: current);
1007 workList.pop_back();
1008 continue;
1009 }
1010 // Add the next dependency from the adjacency list to the `workList` and
1011 // continue the traversal. Remove the dependency from the adjacency list to
1012 // mark that it has been processed. Only enqueue the dependency if it has no
1013 // `valueMapping` from an earlier translation and if it has not been
1014 // enqueued before.
1015 llvm::Constant *dependency = adjacencyIt->getSecond().pop_back_val();
1016 if (valueMapping.contains(Val: dependency) || workList.contains(key: dependency) ||
1017 orderedSet.contains(key: dependency))
1018 continue;
1019 workList.insert(X: dependency);
1020 }
1021
1022 return orderedSet;
1023}
1024
1025FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
1026 Location loc = UnknownLoc::get(context);
1027
1028 // Convert constants that can be represented as attributes.
1029 if (Attribute attr = getConstantAsAttr(constant)) {
1030 Type type = convertType(type: constant->getType());
1031 if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) {
1032 return builder.create<AddressOfOp>(loc, type, symbolRef.getValue())
1033 .getResult();
1034 }
1035 return builder.create<ConstantOp>(loc, type, attr).getResult();
1036 }
1037
1038 // Convert null pointer constants.
1039 if (auto *nullPtr = dyn_cast<llvm::ConstantPointerNull>(Val: constant)) {
1040 Type type = convertType(type: nullPtr->getType());
1041 return builder.create<ZeroOp>(loc, type).getResult();
1042 }
1043
1044 // Convert none token constants.
1045 if (isa<llvm::ConstantTokenNone>(Val: constant)) {
1046 return builder.create<NoneTokenOp>(loc).getResult();
1047 }
1048
1049 // Convert poison.
1050 if (auto *poisonVal = dyn_cast<llvm::PoisonValue>(Val: constant)) {
1051 Type type = convertType(type: poisonVal->getType());
1052 return builder.create<PoisonOp>(loc, type).getResult();
1053 }
1054
1055 // Convert undef.
1056 if (auto *undefVal = dyn_cast<llvm::UndefValue>(Val: constant)) {
1057 Type type = convertType(type: undefVal->getType());
1058 return builder.create<UndefOp>(loc, type).getResult();
1059 }
1060
1061 // Convert global variable accesses.
1062 if (auto *globalVar = dyn_cast<llvm::GlobalVariable>(Val: constant)) {
1063 Type type = convertType(type: globalVar->getType());
1064 auto symbolRef = FlatSymbolRefAttr::get(ctx: context, value: globalVar->getName());
1065 return builder.create<AddressOfOp>(loc, type, symbolRef).getResult();
1066 }
1067
1068 // Convert constant expressions.
1069 if (auto *constExpr = dyn_cast<llvm::ConstantExpr>(Val: constant)) {
1070 // Convert the constant expression to a temporary LLVM instruction and
1071 // translate it using the `processInstruction` method. Delete the
1072 // instruction after the translation and remove it from `valueMapping`,
1073 // since later calls to `getAsInstruction` may return the same address
1074 // resulting in a conflicting `valueMapping` entry.
1075 llvm::Instruction *inst = constExpr->getAsInstruction();
1076 auto guard = llvm::make_scope_exit(F: [&]() {
1077 assert(!noResultOpMapping.contains(inst) &&
1078 "expected constant expression to return a result");
1079 valueMapping.erase(Val: inst);
1080 inst->deleteValue();
1081 });
1082 // Note: `processInstruction` does not call `convertConstant` recursively
1083 // since all constant dependencies have been converted before.
1084 assert(llvm::all_of(inst->operands(), [&](llvm::Value *value) {
1085 return valueMapping.contains(value);
1086 }));
1087 if (failed(result: processInstruction(inst)))
1088 return failure();
1089 return lookupValue(value: inst);
1090 }
1091
1092 // Convert aggregate constants.
1093 if (isa<llvm::ConstantAggregate>(Val: constant) ||
1094 isa<llvm::ConstantAggregateZero>(Val: constant)) {
1095 // Lookup the aggregate elements that have been converted before.
1096 SmallVector<Value> elementValues;
1097 if (auto *constAgg = dyn_cast<llvm::ConstantAggregate>(Val: constant)) {
1098 elementValues.reserve(N: constAgg->getNumOperands());
1099 for (llvm::Value *operand : constAgg->operands())
1100 elementValues.push_back(Elt: lookupValue(value: operand));
1101 }
1102 if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(Val: constant)) {
1103 unsigned numElements = constAgg->getElementCount().getFixedValue();
1104 elementValues.reserve(N: numElements);
1105 for (unsigned i = 0, e = numElements; i != e; ++i)
1106 elementValues.push_back(Elt: lookupValue(value: constAgg->getElementValue(Idx: i)));
1107 }
1108 assert(llvm::count(elementValues, nullptr) == 0 &&
1109 "expected all elements have been converted before");
1110
1111 // Generate an UndefOp as root value and insert the aggregate elements.
1112 Type rootType = convertType(type: constant->getType());
1113 bool isArrayOrStruct = isa<LLVMArrayType, LLVMStructType>(rootType);
1114 assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) &&
1115 "unrecognized aggregate type");
1116 Value root = builder.create<UndefOp>(loc, rootType);
1117 for (const auto &it : llvm::enumerate(First&: elementValues)) {
1118 if (isArrayOrStruct) {
1119 root = builder.create<InsertValueOp>(loc, root, it.value(), it.index());
1120 } else {
1121 Attribute indexAttr = builder.getI32IntegerAttr(it.index());
1122 Value indexValue =
1123 builder.create<ConstantOp>(loc, builder.getI32Type(), indexAttr);
1124 root = builder.create<InsertElementOp>(loc, rootType, root, it.value(),
1125 indexValue);
1126 }
1127 }
1128 return root;
1129 }
1130
1131 if (auto *constTargetNone = dyn_cast<llvm::ConstantTargetNone>(Val: constant)) {
1132 LLVMTargetExtType targetExtType =
1133 cast<LLVMTargetExtType>(convertType(constTargetNone->getType()));
1134 assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) &&
1135 "target extension type does not support zero-initialization");
1136 // Create llvm.mlir.zero operation to represent zero-initialization of
1137 // target extension type.
1138 return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
1139 }
1140
1141 StringRef error = "";
1142 if (isa<llvm::BlockAddress>(Val: constant))
1143 error = " since blockaddress(...) is unsupported";
1144
1145 return emitError(loc) << "unhandled constant: " << diag(value: *constant) << error;
1146}
1147
1148FailureOr<Value> ModuleImport::convertConstantExpr(llvm::Constant *constant) {
1149 // Only call the function for constants that have not been translated before
1150 // since it updates the constant insertion point assuming the converted
1151 // constant has been introduced at the end of the constant section.
1152 assert(!valueMapping.contains(constant) &&
1153 "expected constant has not been converted before");
1154 assert(constantInsertionBlock &&
1155 "expected the constant insertion block to be non-null");
1156
1157 // Insert the constant after the last one or at the start of the entry block.
1158 OpBuilder::InsertionGuard guard(builder);
1159 if (!constantInsertionOp)
1160 builder.setInsertionPointToStart(constantInsertionBlock);
1161 else
1162 builder.setInsertionPointAfter(constantInsertionOp);
1163
1164 // Convert all constants of the expression and add them to `valueMapping`.
1165 SetVector<llvm::Constant *> constantsToConvert =
1166 getConstantsToConvert(constant);
1167 for (llvm::Constant *constantToConvert : constantsToConvert) {
1168 FailureOr<Value> converted = convertConstant(constant: constantToConvert);
1169 if (failed(result: converted))
1170 return failure();
1171 mapValue(llvm: constantToConvert, mlir: *converted);
1172 }
1173
1174 // Update the constant insertion point and return the converted constant.
1175 Value result = lookupValue(value: constant);
1176 constantInsertionOp = result.getDefiningOp();
1177 return result;
1178}
1179
1180FailureOr<Value> ModuleImport::convertValue(llvm::Value *value) {
1181 assert(!isa<llvm::MetadataAsValue>(value) &&
1182 "expected value to not be metadata");
1183
1184 // Return the mapped value if it has been converted before.
1185 auto it = valueMapping.find(Val: value);
1186 if (it != valueMapping.end())
1187 return it->getSecond();
1188
1189 // Convert constants such as immediate values that have no mapping yet.
1190 if (auto *constant = dyn_cast<llvm::Constant>(Val: value))
1191 return convertConstantExpr(constant);
1192
1193 Location loc = UnknownLoc::get(context);
1194 if (auto *inst = dyn_cast<llvm::Instruction>(Val: value))
1195 loc = translateLoc(loc: inst->getDebugLoc());
1196 return emitError(loc) << "unhandled value: " << diag(value: *value);
1197}
1198
1199FailureOr<Value> ModuleImport::convertMetadataValue(llvm::Value *value) {
1200 // A value may be wrapped as metadata, for example, when passed to a debug
1201 // intrinsic. Unwrap these values before the conversion.
1202 auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(Val: value);
1203 if (!nodeAsVal)
1204 return failure();
1205 auto *node = dyn_cast<llvm::ValueAsMetadata>(Val: nodeAsVal->getMetadata());
1206 if (!node)
1207 return failure();
1208 value = node->getValue();
1209
1210 // Return the mapped value if it has been converted before.
1211 auto it = valueMapping.find(Val: value);
1212 if (it != valueMapping.end())
1213 return it->getSecond();
1214
1215 // Convert constants such as immediate values that have no mapping yet.
1216 if (auto *constant = dyn_cast<llvm::Constant>(Val: value))
1217 return convertConstantExpr(constant);
1218 return failure();
1219}
1220
1221FailureOr<SmallVector<Value>>
1222ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
1223 SmallVector<Value> remapped;
1224 remapped.reserve(N: values.size());
1225 for (llvm::Value *value : values) {
1226 FailureOr<Value> converted = convertValue(value);
1227 if (failed(result: converted))
1228 return failure();
1229 remapped.push_back(Elt: *converted);
1230 }
1231 return remapped;
1232}
1233
1234LogicalResult ModuleImport::convertIntrinsicArguments(
1235 ArrayRef<llvm::Value *> values, ArrayRef<unsigned> immArgPositions,
1236 ArrayRef<StringLiteral> immArgAttrNames, SmallVectorImpl<Value> &valuesOut,
1237 SmallVectorImpl<NamedAttribute> &attrsOut) {
1238 assert(immArgPositions.size() == immArgAttrNames.size() &&
1239 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
1240 "length");
1241
1242 SmallVector<llvm::Value *> operands(values);
1243 for (auto [immArgPos, immArgName] :
1244 llvm::zip(t&: immArgPositions, u&: immArgAttrNames)) {
1245 auto &value = operands[immArgPos];
1246 auto *constant = llvm::cast<llvm::Constant>(Val: value);
1247 auto attr = getScalarConstantAsAttr(builder, constant);
1248 assert(attr && attr.getType().isIntOrFloat() &&
1249 "expected immarg to be float or integer constant");
1250 auto nameAttr = StringAttr::get(attr.getContext(), immArgName);
1251 attrsOut.push_back(Elt: {nameAttr, attr});
1252 // Mark matched attribute values as null (so they can be removed below).
1253 value = nullptr;
1254 }
1255
1256 for (llvm::Value *value : operands) {
1257 if (!value)
1258 continue;
1259 auto mlirValue = convertValue(value);
1260 if (failed(result: mlirValue))
1261 return failure();
1262 valuesOut.push_back(Elt: *mlirValue);
1263 }
1264
1265 return success();
1266}
1267
1268IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
1269 IntegerAttr integerAttr;
1270 FailureOr<Value> converted = convertValue(value);
1271 bool success = succeeded(result: converted) &&
1272 matchPattern(*converted, m_Constant(&integerAttr));
1273 assert(success && "expected a constant integer value");
1274 (void)success;
1275 return integerAttr;
1276}
1277
1278FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) {
1279 FloatAttr floatAttr;
1280 FailureOr<Value> converted = convertValue(value);
1281 bool success =
1282 succeeded(result: converted) && matchPattern(*converted, m_Constant(&floatAttr));
1283 assert(success && "expected a constant float value");
1284 (void)success;
1285 return floatAttr;
1286}
1287
1288DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) {
1289 auto *nodeAsVal = cast<llvm::MetadataAsValue>(Val: value);
1290 auto *node = cast<llvm::DILocalVariable>(Val: nodeAsVal->getMetadata());
1291 return debugImporter->translate(node);
1292}
1293
1294DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
1295 auto *nodeAsVal = cast<llvm::MetadataAsValue>(Val: value);
1296 auto *node = cast<llvm::DILabel>(Val: nodeAsVal->getMetadata());
1297 return debugImporter->translate(node);
1298}
1299
1300FPExceptionBehaviorAttr
1301ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
1302 auto *metadata = cast<llvm::MetadataAsValue>(Val: value);
1303 auto *mdstr = cast<llvm::MDString>(Val: metadata->getMetadata());
1304 std::optional<llvm::fp::ExceptionBehavior> optLLVM =
1305 llvm::convertStrToExceptionBehavior(mdstr->getString());
1306 assert(optLLVM && "Expecting FP exception behavior");
1307 return builder.getAttr<FPExceptionBehaviorAttr>(
1308 convertFPExceptionBehaviorFromLLVM(*optLLVM));
1309}
1310
1311RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
1312 auto *metadata = cast<llvm::MetadataAsValue>(Val: value);
1313 auto *mdstr = cast<llvm::MDString>(Val: metadata->getMetadata());
1314 std::optional<llvm::RoundingMode> optLLVM =
1315 llvm::convertStrToRoundingMode(mdstr->getString());
1316 assert(optLLVM && "Expecting rounding mode");
1317 return builder.getAttr<RoundingModeAttr>(
1318 convertRoundingModeFromLLVM(*optLLVM));
1319}
1320
1321FailureOr<SmallVector<AliasScopeAttr>>
1322ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
1323 auto *nodeAsVal = cast<llvm::MetadataAsValue>(Val: value);
1324 auto *node = cast<llvm::MDNode>(Val: nodeAsVal->getMetadata());
1325 return lookupAliasScopeAttrs(node);
1326}
1327
1328Location ModuleImport::translateLoc(llvm::DILocation *loc) {
1329 return debugImporter->translateLoc(loc);
1330}
1331
1332LogicalResult
1333ModuleImport::convertBranchArgs(llvm::Instruction *branch,
1334 llvm::BasicBlock *target,
1335 SmallVectorImpl<Value> &blockArguments) {
1336 for (auto inst = target->begin(); isa<llvm::PHINode>(Val: inst); ++inst) {
1337 auto *phiInst = cast<llvm::PHINode>(Val: &*inst);
1338 llvm::Value *value = phiInst->getIncomingValueForBlock(BB: branch->getParent());
1339 FailureOr<Value> converted = convertValue(value);
1340 if (failed(result: converted))
1341 return failure();
1342 blockArguments.push_back(Elt: *converted);
1343 }
1344 return success();
1345}
1346
1347LogicalResult
1348ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
1349 SmallVectorImpl<Type> &types,
1350 SmallVectorImpl<Value> &operands) {
1351 if (!callInst->getType()->isVoidTy())
1352 types.push_back(Elt: convertType(type: callInst->getType()));
1353
1354 if (!callInst->getCalledFunction()) {
1355 FailureOr<Value> called = convertValue(value: callInst->getCalledOperand());
1356 if (failed(result: called))
1357 return failure();
1358 operands.push_back(Elt: *called);
1359 }
1360 SmallVector<llvm::Value *> args(callInst->args());
1361 FailureOr<SmallVector<Value>> arguments = convertValues(values: args);
1362 if (failed(result: arguments))
1363 return failure();
1364 llvm::append_range(C&: operands, R&: *arguments);
1365 return success();
1366}
1367
1368LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
1369 if (succeeded(result: iface.convertIntrinsic(builder, inst, moduleImport&: *this)))
1370 return success();
1371
1372 Location loc = translateLoc(loc: inst->getDebugLoc());
1373 return emitError(loc) << "unhandled intrinsic: " << diag(value: *inst);
1374}
1375
1376LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1377 // Convert all instructions that do not provide an MLIR builder.
1378 Location loc = translateLoc(loc: inst->getDebugLoc());
1379 if (inst->getOpcode() == llvm::Instruction::Br) {
1380 auto *brInst = cast<llvm::BranchInst>(Val: inst);
1381
1382 SmallVector<Block *> succBlocks;
1383 SmallVector<SmallVector<Value>> succBlockArgs;
1384 for (auto i : llvm::seq<unsigned>(Begin: 0, End: brInst->getNumSuccessors())) {
1385 llvm::BasicBlock *succ = brInst->getSuccessor(i);
1386 SmallVector<Value> blockArgs;
1387 if (failed(result: convertBranchArgs(branch: brInst, target: succ, blockArguments&: blockArgs)))
1388 return failure();
1389 succBlocks.push_back(Elt: lookupBlock(block: succ));
1390 succBlockArgs.push_back(Elt: blockArgs);
1391 }
1392
1393 if (!brInst->isConditional()) {
1394 auto brOp = builder.create<LLVM::BrOp>(loc, succBlockArgs.front(),
1395 succBlocks.front());
1396 mapNoResultOp(inst, brOp);
1397 return success();
1398 }
1399 FailureOr<Value> condition = convertValue(value: brInst->getCondition());
1400 if (failed(result: condition))
1401 return failure();
1402 auto condBrOp = builder.create<LLVM::CondBrOp>(
1403 loc, *condition, succBlocks.front(), succBlockArgs.front(),
1404 succBlocks.back(), succBlockArgs.back());
1405 mapNoResultOp(inst, condBrOp);
1406 return success();
1407 }
1408 if (inst->getOpcode() == llvm::Instruction::Switch) {
1409 auto *swInst = cast<llvm::SwitchInst>(Val: inst);
1410 // Process the condition value.
1411 FailureOr<Value> condition = convertValue(value: swInst->getCondition());
1412 if (failed(result: condition))
1413 return failure();
1414 SmallVector<Value> defaultBlockArgs;
1415 // Process the default case.
1416 llvm::BasicBlock *defaultBB = swInst->getDefaultDest();
1417 if (failed(result: convertBranchArgs(branch: swInst, target: defaultBB, blockArguments&: defaultBlockArgs)))
1418 return failure();
1419
1420 // Process the cases.
1421 unsigned numCases = swInst->getNumCases();
1422 SmallVector<SmallVector<Value>> caseOperands(numCases);
1423 SmallVector<ValueRange> caseOperandRefs(numCases);
1424 SmallVector<APInt> caseValues(numCases);
1425 SmallVector<Block *> caseBlocks(numCases);
1426 for (const auto &it : llvm::enumerate(First: swInst->cases())) {
1427 const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
1428 llvm::BasicBlock *succBB = caseHandle.getCaseSuccessor();
1429 if (failed(result: convertBranchArgs(branch: swInst, target: succBB, blockArguments&: caseOperands[it.index()])))
1430 return failure();
1431 caseOperandRefs[it.index()] = caseOperands[it.index()];
1432 caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
1433 caseBlocks[it.index()] = lookupBlock(block: succBB);
1434 }
1435
1436 auto switchOp = builder.create<SwitchOp>(
1437 loc, *condition, lookupBlock(defaultBB), defaultBlockArgs, caseValues,
1438 caseBlocks, caseOperandRefs);
1439 mapNoResultOp(inst, switchOp);
1440 return success();
1441 }
1442 if (inst->getOpcode() == llvm::Instruction::PHI) {
1443 Type type = convertType(type: inst->getType());
1444 mapValue(llvm: inst, mlir: builder.getInsertionBlock()->addArgument(
1445 type, loc: translateLoc(loc: inst->getDebugLoc())));
1446 return success();
1447 }
1448 if (inst->getOpcode() == llvm::Instruction::Call) {
1449 auto *callInst = cast<llvm::CallInst>(Val: inst);
1450
1451 SmallVector<Type> types;
1452 SmallVector<Value> operands;
1453 if (failed(result: convertCallTypeAndOperands(callInst, types, operands)))
1454 return failure();
1455
1456 auto funcTy =
1457 dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
1458 if (!funcTy)
1459 return failure();
1460
1461 CallOp callOp;
1462
1463 if (llvm::Function *callee = callInst->getCalledFunction()) {
1464 callOp = builder.create<CallOp>(
1465 loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1466 operands);
1467 } else {
1468 callOp = builder.create<CallOp>(loc, funcTy, operands);
1469 }
1470 callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1471 setFastmathFlagsAttr(inst, op: callOp);
1472 if (!callInst->getType()->isVoidTy())
1473 mapValue(inst, callOp.getResult());
1474 else
1475 mapNoResultOp(inst, callOp);
1476 return success();
1477 }
1478 if (inst->getOpcode() == llvm::Instruction::LandingPad) {
1479 auto *lpInst = cast<llvm::LandingPadInst>(Val: inst);
1480
1481 SmallVector<Value> operands;
1482 operands.reserve(N: lpInst->getNumClauses());
1483 for (auto i : llvm::seq<unsigned>(Begin: 0, End: lpInst->getNumClauses())) {
1484 FailureOr<Value> operand = convertValue(value: lpInst->getClause(Idx: i));
1485 if (failed(result: operand))
1486 return failure();
1487 operands.push_back(Elt: *operand);
1488 }
1489
1490 Type type = convertType(type: lpInst->getType());
1491 auto lpOp =
1492 builder.create<LandingpadOp>(loc, type, lpInst->isCleanup(), operands);
1493 mapValue(inst, lpOp);
1494 return success();
1495 }
1496 if (inst->getOpcode() == llvm::Instruction::Invoke) {
1497 auto *invokeInst = cast<llvm::InvokeInst>(Val: inst);
1498
1499 SmallVector<Type> types;
1500 SmallVector<Value> operands;
1501 if (failed(result: convertCallTypeAndOperands(callInst: invokeInst, types, operands)))
1502 return failure();
1503
1504 // Check whether the invoke result is an argument to the normal destination
1505 // block.
1506 bool invokeResultUsedInPhi = llvm::any_of(
1507 Range: invokeInst->getNormalDest()->phis(), P: [&](const llvm::PHINode &phi) {
1508 return phi.getIncomingValueForBlock(BB: invokeInst->getParent()) ==
1509 invokeInst;
1510 });
1511
1512 Block *normalDest = lookupBlock(block: invokeInst->getNormalDest());
1513 Block *directNormalDest = normalDest;
1514 if (invokeResultUsedInPhi) {
1515 // The invoke result cannot be an argument to the normal destination
1516 // block, as that would imply using the invoke operation result in its
1517 // definition, so we need to create a dummy block to serve as an
1518 // intermediate destination.
1519 OpBuilder::InsertionGuard g(builder);
1520 directNormalDest = builder.createBlock(insertBefore: normalDest);
1521 }
1522
1523 SmallVector<Value> unwindArgs;
1524 if (failed(result: convertBranchArgs(branch: invokeInst, target: invokeInst->getUnwindDest(),
1525 blockArguments&: unwindArgs)))
1526 return failure();
1527
1528 auto funcTy =
1529 dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
1530 if (!funcTy)
1531 return failure();
1532
1533 // Create the invoke operation. Normal destination block arguments will be
1534 // added later on to handle the case in which the operation result is
1535 // included in this list.
1536 InvokeOp invokeOp;
1537 if (llvm::Function *callee = invokeInst->getCalledFunction()) {
1538 invokeOp = builder.create<InvokeOp>(
1539 loc, funcTy,
1540 SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
1541 directNormalDest, ValueRange(),
1542 lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1543 } else {
1544 invokeOp = builder.create<InvokeOp>(
1545 loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
1546 ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1547 }
1548 invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
1549 if (!invokeInst->getType()->isVoidTy())
1550 mapValue(inst, invokeOp.getResults().front());
1551 else
1552 mapNoResultOp(inst, invokeOp);
1553
1554 SmallVector<Value> normalArgs;
1555 if (failed(result: convertBranchArgs(branch: invokeInst, target: invokeInst->getNormalDest(),
1556 blockArguments&: normalArgs)))
1557 return failure();
1558
1559 if (invokeResultUsedInPhi) {
1560 // The dummy normal dest block will just host an unconditional branch
1561 // instruction to the normal destination block passing the required block
1562 // arguments (including the invoke operation's result).
1563 OpBuilder::InsertionGuard g(builder);
1564 builder.setInsertionPointToStart(directNormalDest);
1565 builder.create<LLVM::BrOp>(loc, normalArgs, normalDest);
1566 } else {
1567 // If the invoke operation's result is not a block argument to the normal
1568 // destination block, just add the block arguments as usual.
1569 assert(llvm::none_of(
1570 normalArgs,
1571 [&](Value val) { return val.getDefiningOp() == invokeOp; }) &&
1572 "An llvm.invoke operation cannot pass its result as a block "
1573 "argument.");
1574 invokeOp.getNormalDestOperandsMutable().append(normalArgs);
1575 }
1576
1577 return success();
1578 }
1579 if (inst->getOpcode() == llvm::Instruction::GetElementPtr) {
1580 auto *gepInst = cast<llvm::GetElementPtrInst>(Val: inst);
1581 Type sourceElementType = convertType(type: gepInst->getSourceElementType());
1582 FailureOr<Value> basePtr = convertValue(value: gepInst->getOperand(i_nocapture: 0));
1583 if (failed(result: basePtr))
1584 return failure();
1585
1586 // Treat every indices as dynamic since GEPOp::build will refine those
1587 // indices into static attributes later. One small downside of this
1588 // approach is that many unused `llvm.mlir.constant` would be emitted
1589 // at first place.
1590 SmallVector<GEPArg> indices;
1591 for (llvm::Value *operand : llvm::drop_begin(RangeOrContainer: gepInst->operand_values())) {
1592 FailureOr<Value> index = convertValue(value: operand);
1593 if (failed(result: index))
1594 return failure();
1595 indices.push_back(Elt: *index);
1596 }
1597
1598 Type type = convertType(type: inst->getType());
1599 auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
1600 indices, gepInst->isInBounds());
1601 mapValue(inst, gepOp);
1602 return success();
1603 }
1604
1605 // Convert all instructions that have an mlirBuilder.
1606 if (succeeded(result: convertInstructionImpl(odsBuilder&: builder, inst, moduleImport&: *this, iface)))
1607 return success();
1608
1609 return emitError(loc) << "unhandled instruction: " << diag(value: *inst);
1610}
1611
1612LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) {
1613 // FIXME: Support uses of SubtargetData.
1614 // FIXME: Add support for call / operand attributes.
1615 // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
1616 // callbr, vaarg, catchpad, cleanuppad instructions.
1617
1618 // Convert LLVM intrinsics calls to MLIR intrinsics.
1619 if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(Val: inst))
1620 return convertIntrinsic(inst: intrinsic);
1621
1622 // Convert all remaining LLVM instructions to MLIR operations.
1623 return convertInstruction(inst);
1624}
1625
1626FlatSymbolRefAttr ModuleImport::getPersonalityAsAttr(llvm::Function *f) {
1627 if (!f->hasPersonalityFn())
1628 return nullptr;
1629
1630 llvm::Constant *pf = f->getPersonalityFn();
1631
1632 // If it directly has a name, we can use it.
1633 if (pf->hasName())
1634 return SymbolRefAttr::get(builder.getContext(), pf->getName());
1635
1636 // If it doesn't have a name, currently, only function pointers that are
1637 // bitcast to i8* are parsed.
1638 if (auto *ce = dyn_cast<llvm::ConstantExpr>(Val: pf)) {
1639 if (ce->getOpcode() == llvm::Instruction::BitCast &&
1640 ce->getType() == llvm::PointerType::getUnqual(C&: f->getContext())) {
1641 if (auto *func = dyn_cast<llvm::Function>(ce->getOperand(0)))
1642 return SymbolRefAttr::get(builder.getContext(), func->getName());
1643 }
1644 }
1645 return FlatSymbolRefAttr();
1646}
1647
1648static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
1649 llvm::MemoryEffects memEffects = func->getMemoryEffects();
1650
1651 auto othermem = convertModRefInfoFromLLVM(
1652 memEffects.getModRef(Loc: llvm::MemoryEffects::Location::Other));
1653 auto argMem = convertModRefInfoFromLLVM(
1654 memEffects.getModRef(Loc: llvm::MemoryEffects::Location::ArgMem));
1655 auto inaccessibleMem = convertModRefInfoFromLLVM(
1656 memEffects.getModRef(Loc: llvm::MemoryEffects::Location::InaccessibleMem));
1657 auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem,
1658 inaccessibleMem);
1659 // Only set the attr when it does not match the default value.
1660 if (memAttr.isReadWrite())
1661 return;
1662 funcOp.setMemoryAttr(memAttr);
1663}
1664
1665// List of LLVM IR attributes that map to an explicit attribute on the MLIR
1666// LLVMFuncOp.
1667static constexpr std::array ExplicitAttributes{
1668 StringLiteral("aarch64_pstate_sm_enabled"),
1669 StringLiteral("aarch64_pstate_sm_body"),
1670 StringLiteral("aarch64_pstate_sm_compatible"),
1671 StringLiteral("aarch64_new_za"),
1672 StringLiteral("aarch64_preserves_za"),
1673 StringLiteral("aarch64_in_za"),
1674 StringLiteral("aarch64_out_za"),
1675 StringLiteral("aarch64_inout_za"),
1676 StringLiteral("vscale_range"),
1677 StringLiteral("frame-pointer"),
1678 StringLiteral("target-features"),
1679 StringLiteral("unsafe-fp-math"),
1680 StringLiteral("no-infs-fp-math"),
1681 StringLiteral("no-nans-fp-math"),
1682 StringLiteral("approx-func-fp-math"),
1683 StringLiteral("no-signed-zeros-fp-math"),
1684};
1685
1686static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
1687 MLIRContext *context = funcOp.getContext();
1688 SmallVector<Attribute> passthroughs;
1689 llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes(
1690 Index: llvm::AttributeList::AttrIndex::FunctionIndex);
1691 for (llvm::Attribute attr : funcAttrs) {
1692 // Skip the memory attribute since the LLVMFuncOp has an explicit memory
1693 // attribute.
1694 if (attr.hasAttribute(llvm::Attribute::Memory))
1695 continue;
1696
1697 // Skip invalid type attributes.
1698 if (attr.isTypeAttribute()) {
1699 emitWarning(funcOp.getLoc(),
1700 "type attributes on a function are invalid, skipping it");
1701 continue;
1702 }
1703
1704 StringRef attrName;
1705 if (attr.isStringAttribute())
1706 attrName = attr.getKindAsString();
1707 else
1708 attrName = llvm::Attribute::getNameFromAttrKind(AttrKind: attr.getKindAsEnum());
1709 auto keyAttr = StringAttr::get(context, attrName);
1710
1711 // Skip attributes that map to an explicit attribute on the LLVMFuncOp.
1712 if (llvm::is_contained(Range: ExplicitAttributes, Element: attrName))
1713 continue;
1714
1715 if (attr.isStringAttribute()) {
1716 StringRef val = attr.getValueAsString();
1717 if (val.empty()) {
1718 passthroughs.push_back(Elt: keyAttr);
1719 continue;
1720 }
1721 passthroughs.push_back(
1722 ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
1723 continue;
1724 }
1725 if (attr.isIntAttribute()) {
1726 auto val = std::to_string(val: attr.getValueAsInt());
1727 passthroughs.push_back(
1728 ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
1729 continue;
1730 }
1731 if (attr.isEnumAttribute()) {
1732 passthroughs.push_back(Elt: keyAttr);
1733 continue;
1734 }
1735
1736 llvm_unreachable("unexpected attribute kind");
1737 }
1738
1739 if (!passthroughs.empty())
1740 funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs));
1741}
1742
1743void ModuleImport::processFunctionAttributes(llvm::Function *func,
1744 LLVMFuncOp funcOp) {
1745 processMemoryEffects(func, funcOp);
1746 processPassthroughAttrs(func, funcOp);
1747
1748 if (func->hasFnAttribute(Kind: "aarch64_pstate_sm_enabled"))
1749 funcOp.setArmStreaming(true);
1750 else if (func->hasFnAttribute(Kind: "aarch64_pstate_sm_body"))
1751 funcOp.setArmLocallyStreaming(true);
1752 else if (func->hasFnAttribute(Kind: "aarch64_pstate_sm_compatible"))
1753 funcOp.setArmStreamingCompatible(true);
1754
1755 if (func->hasFnAttribute(Kind: "aarch64_new_za"))
1756 funcOp.setArmNewZa(true);
1757 else if (func->hasFnAttribute(Kind: "aarch64_in_za"))
1758 funcOp.setArmInZa(true);
1759 else if (func->hasFnAttribute(Kind: "aarch64_out_za"))
1760 funcOp.setArmOutZa(true);
1761 else if (func->hasFnAttribute(Kind: "aarch64_inout_za"))
1762 funcOp.setArmInoutZa(true);
1763 else if (func->hasFnAttribute(Kind: "aarch64_preserves_za"))
1764 funcOp.setArmPreservesZa(true);
1765
1766 llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
1767 if (attr.isValid()) {
1768 MLIRContext *context = funcOp.getContext();
1769 auto intTy = IntegerType::get(context, 32);
1770 funcOp.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get(
1771 context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
1772 IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0))));
1773 }
1774
1775 // Process frame-pointer attribute.
1776 if (func->hasFnAttribute(Kind: "frame-pointer")) {
1777 StringRef stringRefFramePointerKind =
1778 func->getFnAttribute(Kind: "frame-pointer").getValueAsString();
1779 funcOp.setFramePointerAttr(LLVM::FramePointerKindAttr::get(
1780 funcOp.getContext(), LLVM::framePointerKind::symbolizeFramePointerKind(
1781 stringRefFramePointerKind)
1782 .value()));
1783 }
1784
1785 if (llvm::Attribute attr = func->getFnAttribute("target-cpu");
1786 attr.isStringAttribute())
1787 funcOp.setTargetCpuAttr(StringAttr::get(context, attr.getValueAsString()));
1788
1789 if (llvm::Attribute attr = func->getFnAttribute("target-features");
1790 attr.isStringAttribute())
1791 funcOp.setTargetFeaturesAttr(
1792 LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
1793
1794 if (llvm::Attribute attr = func->getFnAttribute(Kind: "unsafe-fp-math");
1795 attr.isStringAttribute())
1796 funcOp.setUnsafeFpMath(attr.getValueAsBool());
1797
1798 if (llvm::Attribute attr = func->getFnAttribute(Kind: "no-infs-fp-math");
1799 attr.isStringAttribute())
1800 funcOp.setNoInfsFpMath(attr.getValueAsBool());
1801
1802 if (llvm::Attribute attr = func->getFnAttribute(Kind: "no-nans-fp-math");
1803 attr.isStringAttribute())
1804 funcOp.setNoNansFpMath(attr.getValueAsBool());
1805
1806 if (llvm::Attribute attr = func->getFnAttribute(Kind: "approx-func-fp-math");
1807 attr.isStringAttribute())
1808 funcOp.setApproxFuncFpMath(attr.getValueAsBool());
1809
1810 if (llvm::Attribute attr = func->getFnAttribute(Kind: "no-signed-zeros-fp-math");
1811 attr.isStringAttribute())
1812 funcOp.setNoSignedZerosFpMath(attr.getValueAsBool());
1813}
1814
1815DictionaryAttr
1816ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
1817 OpBuilder &builder) {
1818 SmallVector<NamedAttribute> paramAttrs;
1819 for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
1820 auto llvmAttr = llvmParamAttrs.getAttribute(Kind: llvmKind);
1821 // Skip attributes that are not attached.
1822 if (!llvmAttr.isValid())
1823 continue;
1824 Attribute mlirAttr;
1825 if (llvmAttr.isTypeAttribute())
1826 mlirAttr = TypeAttr::get(convertType(llvmAttr.getValueAsType()));
1827 else if (llvmAttr.isIntAttribute())
1828 mlirAttr = builder.getI64IntegerAttr(llvmAttr.getValueAsInt());
1829 else if (llvmAttr.isEnumAttribute())
1830 mlirAttr = builder.getUnitAttr();
1831 else
1832 llvm_unreachable("unexpected parameter attribute kind");
1833 paramAttrs.push_back(Elt: builder.getNamedAttr(name: mlirName, val: mlirAttr));
1834 }
1835
1836 return builder.getDictionaryAttr(paramAttrs);
1837}
1838
1839void ModuleImport::convertParameterAttributes(llvm::Function *func,
1840 LLVMFuncOp funcOp,
1841 OpBuilder &builder) {
1842 auto llvmAttrs = func->getAttributes();
1843 for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
1844 llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(ArgNo: i);
1845 funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
1846 }
1847 // Convert the result attributes and attach them wrapped in an ArrayAttribute
1848 // to the funcOp.
1849 llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
1850 if (!llvmResAttr.hasAttributes())
1851 return;
1852 funcOp.setResAttrsAttr(
1853 builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
1854}
1855
1856LogicalResult ModuleImport::processFunction(llvm::Function *func) {
1857 clearRegionState();
1858
1859 auto functionType =
1860 dyn_cast<LLVMFunctionType>(convertType(func->getFunctionType()));
1861 if (func->isIntrinsic() &&
1862 iface.isConvertibleIntrinsic(id: func->getIntrinsicID()))
1863 return success();
1864
1865 bool dsoLocal = func->hasLocalLinkage();
1866 CConv cconv = convertCConvFromLLVM(func->getCallingConv());
1867
1868 // Insert the function at the end of the module.
1869 OpBuilder::InsertionGuard guard(builder);
1870 builder.setInsertionPoint(mlirModule.getBody(), mlirModule.getBody()->end());
1871
1872 Location loc = debugImporter->translateFuncLocation(func);
1873 LLVMFuncOp funcOp = builder.create<LLVMFuncOp>(
1874 loc, func->getName(), functionType,
1875 convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
1876
1877 convertParameterAttributes(func, funcOp, builder);
1878
1879 if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
1880 funcOp.setPersonalityAttr(personality);
1881 else if (func->hasPersonalityFn())
1882 emitWarning(funcOp.getLoc(), "could not deduce personality, skipping it");
1883
1884 if (func->hasGC())
1885 funcOp.setGarbageCollector(StringRef(func->getGC()));
1886
1887 if (func->hasAtLeastLocalUnnamedAddr())
1888 funcOp.setUnnamedAddr(convertUnnamedAddrFromLLVM(func->getUnnamedAddr()));
1889
1890 if (func->hasSection())
1891 funcOp.setSection(StringRef(func->getSection()));
1892
1893 funcOp.setVisibility_(convertVisibilityFromLLVM(func->getVisibility()));
1894
1895 if (func->hasComdat())
1896 funcOp.setComdatAttr(comdatMapping.lookup(func->getComdat()));
1897
1898 if (llvm::MaybeAlign maybeAlign = func->getAlign())
1899 funcOp.setAlignment(maybeAlign->value());
1900
1901 // Handle Function attributes.
1902 processFunctionAttributes(func, funcOp);
1903
1904 // Convert non-debug metadata by using the dialect interface.
1905 SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
1906 func->getAllMetadata(MDs&: allMetadata);
1907 for (auto &[kind, node] : allMetadata) {
1908 if (!iface.isConvertibleMetadata(kind))
1909 continue;
1910 if (failed(iface.setMetadataAttrs(builder, kind, node, op: funcOp, moduleImport&: *this))) {
1911 emitWarning(funcOp.getLoc())
1912 << "unhandled function metadata: " << diagMD(node, module: llvmModule.get())
1913 << " on " << diag(value: *func);
1914 }
1915 }
1916
1917 if (func->isDeclaration())
1918 return success();
1919
1920 // Collect the set of basic blocks reachable from the function's entry block.
1921 // This step is crucial as LLVM IR can contain unreachable blocks that
1922 // self-dominate. As a result, an operation might utilize a variable it
1923 // defines, which the import does not support. Given that MLIR lacks block
1924 // label support, we can safely remove unreachable blocks, as there are no
1925 // indirect branch instructions that could potentially target these blocks.
1926 llvm::df_iterator_default_set<llvm::BasicBlock *> reachable;
1927 for (llvm::BasicBlock *basicBlock : llvm::depth_first_ext(G: func, S&: reachable))
1928 (void)basicBlock;
1929
1930 // Eagerly create all reachable blocks.
1931 SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
1932 for (llvm::BasicBlock &basicBlock : *func) {
1933 // Skip unreachable blocks.
1934 if (!reachable.contains(Ptr: &basicBlock))
1935 continue;
1936 Region &body = funcOp.getBody();
1937 Block *block = builder.createBlock(parent: &body, insertPt: body.end());
1938 mapBlock(llvm: &basicBlock, mlir: block);
1939 reachableBasicBlocks.push_back(&basicBlock);
1940 }
1941
1942 // Add function arguments to the entry block.
1943 for (const auto &it : llvm::enumerate(First: func->args())) {
1944 BlockArgument blockArg = funcOp.getFunctionBody().addArgument(
1945 functionType.getParamType(it.index()), funcOp.getLoc());
1946 mapValue(llvm: &it.value(), mlir: blockArg);
1947 }
1948
1949 // Process the blocks in topological order. The ordered traversal ensures
1950 // operands defined in a dominating block have a valid mapping to an MLIR
1951 // value once a block is translated.
1952 SetVector<llvm::BasicBlock *> blocks =
1953 getTopologicallySortedBlocks(reachableBasicBlocks);
1954 setConstantInsertionPointToStart(lookupBlock(blocks.front()));
1955 for (llvm::BasicBlock *basicBlock : blocks)
1956 if (failed(processBasicBlock(basicBlock, lookupBlock(basicBlock))))
1957 return failure();
1958
1959 // Process the debug intrinsics that require a delayed conversion after
1960 // everything else was converted.
1961 if (failed(result: processDebugIntrinsics()))
1962 return failure();
1963
1964 return success();
1965}
1966
1967/// Checks if `dbgIntr` is a kill location that holds metadata instead of an SSA
1968/// value.
1969static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) {
1970 if (!dbgIntr->isKillLocation())
1971 return false;
1972 llvm::Value *value = dbgIntr->getArgOperand(i: 0);
1973 auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(Val: value);
1974 if (!nodeAsVal)
1975 return false;
1976 return !isa<llvm::ValueAsMetadata>(Val: nodeAsVal->getMetadata());
1977}
1978
1979LogicalResult
1980ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
1981 DominanceInfo &domInfo) {
1982 Location loc = translateLoc(loc: dbgIntr->getDebugLoc());
1983 auto emitUnsupportedWarning = [&]() {
1984 if (emitExpensiveWarnings)
1985 emitWarning(loc) << "dropped intrinsic: " << diag(value: *dbgIntr);
1986 return success();
1987 };
1988 // Drop debug intrinsics with arg lists.
1989 // TODO: Support debug intrinsics that have arg lists.
1990 if (dbgIntr->hasArgList())
1991 return emitUnsupportedWarning();
1992 // Kill locations can have metadata nodes as location operand. This
1993 // cannot be converted to poison as the type cannot be reconstructed.
1994 // TODO: find a way to support this case.
1995 if (isMetadataKillLocation(dbgIntr))
1996 return emitUnsupportedWarning();
1997 // Drop debug intrinsics if the associated variable information cannot be
1998 // translated due to cyclic debug metadata.
1999 // TODO: Support cyclic debug metadata.
2000 DILocalVariableAttr localVariableAttr =
2001 matchLocalVariableAttr(dbgIntr->getArgOperand(1));
2002 if (!localVariableAttr)
2003 return emitUnsupportedWarning();
2004 FailureOr<Value> argOperand = convertMetadataValue(value: dbgIntr->getArgOperand(i: 0));
2005 if (failed(result: argOperand))
2006 return emitError(loc) << "failed to convert a debug intrinsic operand: "
2007 << diag(value: *dbgIntr);
2008
2009 // Ensure that the debug instrinsic is inserted right after its operand is
2010 // defined. Otherwise, the operand might not necessarily dominate the
2011 // intrinsic. If the defining operation is a terminator, insert the intrinsic
2012 // into a dominated block.
2013 OpBuilder::InsertionGuard guard(builder);
2014 if (Operation *op = argOperand->getDefiningOp();
2015 op && op->hasTrait<OpTrait::IsTerminator>()) {
2016 // Find a dominated block that can hold the debug intrinsic.
2017 auto dominatedBlocks = domInfo.getNode(a: op->getBlock())->children();
2018 // If no block is dominated by the terminator, this intrinisc cannot be
2019 // converted.
2020 if (dominatedBlocks.empty())
2021 return emitUnsupportedWarning();
2022 // Set insertion point before the terminator, to avoid inserting something
2023 // before landingpads.
2024 Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock();
2025 builder.setInsertionPoint(dominatedBlock->getTerminator());
2026 } else {
2027 builder.setInsertionPointAfterValue(*argOperand);
2028 }
2029 auto locationExprAttr =
2030 debugImporter->translateExpression(dbgIntr->getExpression());
2031 Operation *op =
2032 llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr)
2033 .Case(caseFn: [&](llvm::DbgDeclareInst *) {
2034 return builder.create<LLVM::DbgDeclareOp>(
2035 loc, *argOperand, localVariableAttr, locationExprAttr);
2036 })
2037 .Case(caseFn: [&](llvm::DbgValueInst *) {
2038 return builder.create<LLVM::DbgValueOp>(
2039 loc, *argOperand, localVariableAttr, locationExprAttr);
2040 });
2041 mapNoResultOp(llvm: dbgIntr, mlir: op);
2042 setNonDebugMetadataAttrs(inst: dbgIntr, op);
2043 return success();
2044}
2045
2046LogicalResult ModuleImport::processDebugIntrinsics() {
2047 DominanceInfo domInfo;
2048 for (llvm::Instruction *inst : debugIntrinsics) {
2049 auto *intrCall = cast<llvm::DbgVariableIntrinsic>(Val: inst);
2050 if (failed(result: processDebugIntrinsic(dbgIntr: intrCall, domInfo)))
2051 return failure();
2052 }
2053 return success();
2054}
2055
2056LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
2057 Block *block) {
2058 builder.setInsertionPointToStart(block);
2059 for (llvm::Instruction &inst : *bb) {
2060 if (failed(result: processInstruction(inst: &inst)))
2061 return failure();
2062
2063 // Skip additional processing when the instructions is a debug intrinsics
2064 // that was not yet converted.
2065 if (debugIntrinsics.contains(key: &inst))
2066 continue;
2067
2068 // Set the non-debug metadata attributes on the imported operation and emit
2069 // a warning if an instruction other than a phi instruction is dropped
2070 // during the import.
2071 if (Operation *op = lookupOperation(inst: &inst)) {
2072 setNonDebugMetadataAttrs(inst: &inst, op);
2073 } else if (inst.getOpcode() != llvm::Instruction::PHI) {
2074 if (emitExpensiveWarnings) {
2075 Location loc = debugImporter->translateLoc(loc: inst.getDebugLoc());
2076 emitWarning(loc) << "dropped instruction: " << diag(value: inst);
2077 }
2078 }
2079 }
2080 return success();
2081}
2082
2083FailureOr<SmallVector<AccessGroupAttr>>
2084ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
2085 return loopAnnotationImporter->lookupAccessGroupAttrs(node);
2086}
2087
2088LoopAnnotationAttr
2089ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
2090 Location loc) const {
2091 return loopAnnotationImporter->translateLoopAnnotation(node, loc);
2092}
2093
2094OwningOpRef<ModuleOp>
2095mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
2096 MLIRContext *context, bool emitExpensiveWarnings,
2097 bool dropDICompositeTypeElements) {
2098 // Preload all registered dialects to allow the import to iterate the
2099 // registered LLVMImportDialectInterface implementations and query the
2100 // supported LLVM IR constructs before starting the translation. Assumes the
2101 // LLVM and DLTI dialects that convert the core LLVM IR constructs have been
2102 // registered before.
2103 assert(llvm::is_contained(context->getAvailableDialects(),
2104 LLVMDialect::getDialectNamespace()));
2105 assert(llvm::is_contained(context->getAvailableDialects(),
2106 DLTIDialect::getDialectNamespace()));
2107 context->loadAllAvailableDialects();
2108 OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
2109 StringAttr::get(context, llvmModule->getSourceFileName()), /*line=*/0,
2110 /*column=*/0)));
2111
2112 ModuleImport moduleImport(module.get(), std::move(llvmModule),
2113 emitExpensiveWarnings, dropDICompositeTypeElements);
2114 if (failed(result: moduleImport.initializeImportInterface()))
2115 return {};
2116 if (failed(result: moduleImport.convertDataLayout()))
2117 return {};
2118 if (failed(result: moduleImport.convertComdats()))
2119 return {};
2120 if (failed(result: moduleImport.convertMetadata()))
2121 return {};
2122 if (failed(result: moduleImport.convertGlobals()))
2123 return {};
2124 if (failed(result: moduleImport.convertFunctions()))
2125 return {};
2126
2127 return module;
2128}
2129

source code of mlir/lib/Target/LLVMIR/ModuleImport.cpp