1 | //===- ModuleImport.cpp - LLVM to MLIR conversion ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the import of an LLVM IR module into an LLVM dialect |
10 | // module. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Target/LLVMIR/ModuleImport.h" |
15 | #include "mlir/Target/LLVMIR/Import.h" |
16 | |
17 | #include "AttrKindDetail.h" |
18 | #include "DataLayoutImporter.h" |
19 | #include "DebugImporter.h" |
20 | #include "LoopAnnotationImporter.h" |
21 | |
22 | #include "mlir/Dialect/DLTI/DLTI.h" |
23 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
24 | #include "mlir/IR/Builders.h" |
25 | #include "mlir/IR/Matchers.h" |
26 | #include "mlir/Interfaces/DataLayoutInterfaces.h" |
27 | #include "mlir/Tools/mlir-translate/Translation.h" |
28 | |
29 | #include "llvm/ADT/DepthFirstIterator.h" |
30 | #include "llvm/ADT/PostOrderIterator.h" |
31 | #include "llvm/ADT/ScopeExit.h" |
32 | #include "llvm/ADT/StringSet.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | #include "llvm/IR/Comdat.h" |
35 | #include "llvm/IR/Constants.h" |
36 | #include "llvm/IR/InlineAsm.h" |
37 | #include "llvm/IR/InstIterator.h" |
38 | #include "llvm/IR/Instructions.h" |
39 | #include "llvm/IR/IntrinsicInst.h" |
40 | #include "llvm/IR/Metadata.h" |
41 | #include "llvm/IR/Operator.h" |
42 | #include "llvm/Support/ModRef.h" |
43 | |
44 | using namespace mlir; |
45 | using namespace mlir::LLVM; |
46 | using namespace mlir::LLVM::detail; |
47 | |
48 | #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc" |
49 | |
50 | // Utility to print an LLVM value as a string for passing to emitError(). |
51 | // FIXME: Diagnostic should be able to natively handle types that have |
52 | // operator << (raw_ostream&) defined. |
53 | static std::string diag(const llvm::Value &value) { |
54 | std::string str; |
55 | llvm::raw_string_ostream os(str); |
56 | os << value; |
57 | return os.str(); |
58 | } |
59 | |
60 | // Utility to print an LLVM metadata node as a string for passing |
61 | // to emitError(). The module argument is needed to print the nodes |
62 | // canonically numbered. |
63 | static std::string diagMD(const llvm::Metadata *node, |
64 | const llvm::Module *module) { |
65 | std::string str; |
66 | llvm::raw_string_ostream os(str); |
67 | node->print(OS&: os, M: module, /*IsForDebug=*/true); |
68 | return os.str(); |
69 | } |
70 | |
71 | /// Returns the name of the global_ctors global variables. |
72 | static constexpr StringRef getGlobalCtorsVarName() { |
73 | return "llvm.global_ctors" ; |
74 | } |
75 | |
76 | /// Returns the name of the global_dtors global variables. |
77 | static constexpr StringRef getGlobalDtorsVarName() { |
78 | return "llvm.global_dtors" ; |
79 | } |
80 | |
81 | /// Returns the symbol name for the module-level comdat operation. It must not |
82 | /// conflict with the user namespace. |
83 | static constexpr StringRef getGlobalComdatOpName() { |
84 | return "__llvm_global_comdat" ; |
85 | } |
86 | |
87 | /// Converts the sync scope identifier of `inst` to the string representation |
88 | /// necessary to build an atomic LLVM dialect operation. Returns the empty |
89 | /// string if the operation has either no sync scope or the default system-level |
90 | /// sync scope attached. The atomic operations only set their sync scope |
91 | /// attribute if they have a non-default sync scope attached. |
92 | static StringRef getLLVMSyncScope(llvm::Instruction *inst) { |
93 | std::optional<llvm::SyncScope::ID> syncScopeID = |
94 | llvm::getAtomicSyncScopeID(I: inst); |
95 | if (!syncScopeID) |
96 | return "" ; |
97 | |
98 | // Search the sync scope name for the given identifier. The default |
99 | // system-level sync scope thereby maps to the empty string. |
100 | SmallVector<StringRef> syncScopeName; |
101 | llvm::LLVMContext &llvmContext = inst->getContext(); |
102 | llvmContext.getSyncScopeNames(SSNs&: syncScopeName); |
103 | auto *it = llvm::find_if(Range&: syncScopeName, P: [&](StringRef name) { |
104 | return *syncScopeID == llvmContext.getOrInsertSyncScopeID(SSN: name); |
105 | }); |
106 | if (it != syncScopeName.end()) |
107 | return *it; |
108 | llvm_unreachable("incorrect sync scope identifier" ); |
109 | } |
110 | |
111 | /// Converts an array of unsigned indices to a signed integer position array. |
112 | static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) { |
113 | SmallVector<int64_t> position; |
114 | llvm::append_range(C&: position, R&: indices); |
115 | return position; |
116 | } |
117 | |
/// Converts the LLVM instructions that have a generated MLIR builder. Using a
/// static implementation method called from the module import ensures the
/// builders have to use the `moduleImport` argument and cannot directly call
/// import methods. As a result, both the intrinsic and the instruction MLIR
/// builders have to use the `moduleImport` argument and none of them has direct
/// access to the private module import methods.
///
/// Instructions whose opcode `iface` reports as convertible are delegated to
/// `iface.convertInstruction`. All other instructions fall through to the
/// generated conversion code included below. If that code does not return
/// (presumably it returns on a matching opcode), failure() is returned.
static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
                                            llvm::Instruction *inst,
                                            ModuleImport &moduleImport,
                                            LLVMImportInterface &iface) {
  // Copy the operands to an LLVM operands array reference for conversion.
  SmallVector<llvm::Value *> operands(inst->operands());
  ArrayRef<llvm::Value *> llvmOperands(operands);

  // Convert all instructions that provide an MLIR builder.
  if (iface.isConvertibleInstruction(inst->getOpcode()))
    return iface.convertInstruction(odsBuilder, inst, llvmOperands,
                                    moduleImport);
  // TODO: Implement the `convertInstruction` hooks in the
  // `LLVMDialectLLVMIRImportInterface` and move the following include there.
  // NOTE(review): the included generated code is expanded textually inside
  // this function and presumably refers to the locals above (`odsBuilder`,
  // `inst`, `llvmOperands`, `moduleImport`); do not rename them without
  // checking the generator.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
  return failure();
}
141 | |
142 | /// Get a topologically sorted list of blocks for the given basic blocks. |
143 | static SetVector<llvm::BasicBlock *> |
144 | getTopologicallySortedBlocks(ArrayRef<llvm::BasicBlock *> basicBlocks) { |
145 | SetVector<llvm::BasicBlock *> blocks; |
146 | for (llvm::BasicBlock *basicBlock : basicBlocks) { |
147 | if (!blocks.contains(basicBlock)) { |
148 | llvm::ReversePostOrderTraversal<llvm::BasicBlock *> traversal(basicBlock); |
149 | blocks.insert(traversal.begin(), traversal.end()); |
150 | } |
151 | } |
152 | assert(blocks.size() == basicBlocks.size() && "some blocks are not sorted" ); |
153 | return blocks; |
154 | } |
155 | |
/// Creates an importer that translates `llvmModule` into `mlirModule`.
/// Ownership of `llvmModule` is taken. `emitExpensiveWarnings` and
/// `importEmptyDICompositeTypes` are stored in, or forwarded to, the
/// corresponding members/sub-importers. The builder starts out positioned at
/// the beginning of the MLIR module body.
ModuleImport::ModuleImport(ModuleOp mlirModule,
                           std::unique_ptr<llvm::Module> llvmModule,
                           bool emitExpensiveWarnings,
                           bool importEmptyDICompositeTypes)
    // Inside this initializer list, the unqualified names refer to the
    // constructor parameters, which shadow the identically named members.
    : builder(mlirModule->getContext()), context(mlirModule->getContext()),
      mlirModule(mlirModule), llvmModule(std::move(llvmModule)),
      iface(mlirModule->getContext()),
      typeTranslator(*mlirModule->getContext()),
      debugImporter(std::make_unique<DebugImporter>(
          mlirModule, importEmptyDICompositeTypes)),
      // NOTE(review): `*this` is still under construction here; presumably the
      // loop annotation importer only stores the reference. Confirm.
      loopAnnotationImporter(
          std::make_unique<LoopAnnotationImporter>(*this, builder)),
      emitExpensiveWarnings(emitExpensiveWarnings) {
  builder.setInsertionPointToStart(mlirModule.getBody());
}
171 | |
172 | ComdatOp ModuleImport::getGlobalComdatOp() { |
173 | if (globalComdatOp) |
174 | return globalComdatOp; |
175 | |
176 | OpBuilder::InsertionGuard guard(builder); |
177 | builder.setInsertionPointToEnd(mlirModule.getBody()); |
178 | globalComdatOp = |
179 | builder.create<ComdatOp>(mlirModule.getLoc(), getGlobalComdatOpName()); |
180 | globalInsertionOp = globalComdatOp; |
181 | return globalComdatOp; |
182 | } |
183 | |
184 | LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) { |
185 | Location loc = mlirModule.getLoc(); |
186 | |
187 | // If `node` is a valid TBAA root node, then return its optional identity |
188 | // string, otherwise return failure. |
189 | auto getIdentityIfRootNode = |
190 | [&](const llvm::MDNode *node) -> FailureOr<std::optional<StringRef>> { |
191 | // Root node, e.g.: |
192 | // !0 = !{!"Simple C/C++ TBAA"} |
193 | // !1 = !{} |
194 | if (node->getNumOperands() > 1) |
195 | return failure(); |
196 | // If the operand is MDString, then assume that this is a root node. |
197 | if (node->getNumOperands() == 1) |
198 | if (const auto *op0 = dyn_cast<const llvm::MDString>(Val: node->getOperand(I: 0))) |
199 | return std::optional<StringRef>{op0->getString()}; |
200 | return std::optional<StringRef>{}; |
201 | }; |
202 | |
203 | // If `node` looks like a TBAA type descriptor metadata, |
204 | // then return true, if it is a valid node, and false otherwise. |
205 | // If it does not look like a TBAA type descriptor metadata, then |
206 | // return std::nullopt. |
207 | // If `identity` and `memberTypes/Offsets` are non-null, then they will |
208 | // contain the converted metadata operands for a valid TBAA node (i.e. when |
209 | // true is returned). |
210 | auto isTypeDescriptorNode = [&](const llvm::MDNode *node, |
211 | StringRef *identity = nullptr, |
212 | SmallVectorImpl<TBAAMemberAttr> *members = |
213 | nullptr) -> std::optional<bool> { |
214 | unsigned numOperands = node->getNumOperands(); |
215 | // Type descriptor, e.g.: |
216 | // !1 = !{!"int", !0, /*optional*/i64 0} /* scalar int type */ |
217 | // !2 = !{!"agg_t", !1, i64 0} /* struct agg_t { int x; } */ |
218 | if (numOperands < 2) |
219 | return std::nullopt; |
220 | |
221 | // TODO: support "new" format (D41501) for type descriptors, |
222 | // where the first operand is an MDNode. |
223 | const auto *identityNode = |
224 | dyn_cast<const llvm::MDString>(Val: node->getOperand(I: 0)); |
225 | if (!identityNode) |
226 | return std::nullopt; |
227 | |
228 | // This should be a type descriptor node. |
229 | if (identity) |
230 | *identity = identityNode->getString(); |
231 | |
232 | for (unsigned pairNum = 0, e = numOperands / 2; pairNum < e; ++pairNum) { |
233 | const auto *memberNode = |
234 | dyn_cast<const llvm::MDNode>(Val: node->getOperand(I: 2 * pairNum + 1)); |
235 | if (!memberNode) { |
236 | emitError(loc) << "operand '" << 2 * pairNum + 1 << "' must be MDNode: " |
237 | << diagMD(node, module: llvmModule.get()); |
238 | return false; |
239 | } |
240 | int64_t offset = 0; |
241 | if (2 * pairNum + 2 >= numOperands) { |
242 | // Allow for optional 0 offset in 2-operand nodes. |
243 | if (numOperands != 2) { |
244 | emitError(loc) << "missing member offset: " |
245 | << diagMD(node, module: llvmModule.get()); |
246 | return false; |
247 | } |
248 | } else { |
249 | auto *offsetCI = llvm::mdconst::dyn_extract<llvm::ConstantInt>( |
250 | MD: node->getOperand(I: 2 * pairNum + 2)); |
251 | if (!offsetCI) { |
252 | emitError(loc) << "operand '" << 2 * pairNum + 2 |
253 | << "' must be ConstantInt: " |
254 | << diagMD(node, module: llvmModule.get()); |
255 | return false; |
256 | } |
257 | offset = offsetCI->getZExtValue(); |
258 | } |
259 | |
260 | if (members) |
261 | members->push_back(TBAAMemberAttr::get( |
262 | cast<TBAANodeAttr>(tbaaMapping.lookup(memberNode)), offset)); |
263 | } |
264 | |
265 | return true; |
266 | }; |
267 | |
268 | // If `node` looks like a TBAA access tag metadata, |
269 | // then return true, if it is a valid node, and false otherwise. |
270 | // If it does not look like a TBAA access tag metadata, then |
271 | // return std::nullopt. |
272 | // If the other arguments are non-null, then they will contain |
273 | // the converted metadata operands for a valid TBAA node (i.e. when true is |
274 | // returned). |
275 | auto isTagNode = [&](const llvm::MDNode *node, |
276 | TBAATypeDescriptorAttr *baseAttr = nullptr, |
277 | TBAATypeDescriptorAttr *accessAttr = nullptr, |
278 | int64_t *offset = nullptr, |
279 | bool *isConstant = nullptr) -> std::optional<bool> { |
280 | // Access tag, e.g.: |
281 | // !3 = !{!1, !1, i64 0} /* scalar int access */ |
282 | // !4 = !{!2, !1, i64 0} /* agg_t::x access */ |
283 | // |
284 | // Optional 4th argument is ConstantInt 0/1 identifying whether |
285 | // the location being accessed is "constant" (see for details: |
286 | // https://llvm.org/docs/LangRef.html#representation). |
287 | unsigned numOperands = node->getNumOperands(); |
288 | if (numOperands != 3 && numOperands != 4) |
289 | return std::nullopt; |
290 | const auto *baseMD = dyn_cast<const llvm::MDNode>(Val: node->getOperand(I: 0)); |
291 | const auto *accessMD = dyn_cast<const llvm::MDNode>(Val: node->getOperand(I: 1)); |
292 | auto *offsetCI = |
293 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: node->getOperand(I: 2)); |
294 | if (!baseMD || !accessMD || !offsetCI) |
295 | return std::nullopt; |
296 | // TODO: support "new" TBAA format, if needed (see D41501). |
297 | // In the "old" format the first operand of the access type |
298 | // metadata is MDString. We have to distinguish the formats, |
299 | // because access tags have the same structure, but different |
300 | // meaning for the operands. |
301 | if (accessMD->getNumOperands() < 1 || |
302 | !isa<llvm::MDString>(Val: accessMD->getOperand(I: 0))) |
303 | return std::nullopt; |
304 | bool isConst = false; |
305 | if (numOperands == 4) { |
306 | auto *isConstantCI = |
307 | llvm::mdconst::dyn_extract<llvm::ConstantInt>(MD: node->getOperand(I: 3)); |
308 | if (!isConstantCI) { |
309 | emitError(loc) << "operand '3' must be ConstantInt: " |
310 | << diagMD(node, module: llvmModule.get()); |
311 | return false; |
312 | } |
313 | isConst = isConstantCI->getValue()[0]; |
314 | } |
315 | if (baseAttr) |
316 | *baseAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(baseMD)); |
317 | if (accessAttr) |
318 | *accessAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(accessMD)); |
319 | if (offset) |
320 | *offset = offsetCI->getZExtValue(); |
321 | if (isConstant) |
322 | *isConstant = isConst; |
323 | return true; |
324 | }; |
325 | |
326 | // Do a post-order walk over the TBAA Graph. Since a correct TBAA Graph is a |
327 | // DAG, a post-order walk guarantees that we convert any metadata node we |
328 | // depend on, prior to converting the current node. |
329 | DenseSet<const llvm::MDNode *> seen; |
330 | SmallVector<const llvm::MDNode *> workList; |
331 | workList.push_back(Elt: node); |
332 | while (!workList.empty()) { |
333 | const llvm::MDNode *current = workList.back(); |
334 | if (tbaaMapping.contains(Val: current)) { |
335 | // Already converted. Just pop from the worklist. |
336 | workList.pop_back(); |
337 | continue; |
338 | } |
339 | |
340 | // If any child of this node is not yet converted, don't pop the current |
341 | // node from the worklist but push the not-yet-converted children in the |
342 | // front of the worklist. |
343 | bool anyChildNotConverted = false; |
344 | for (const llvm::MDOperand &operand : current->operands()) |
345 | if (auto *childNode = dyn_cast_or_null<const llvm::MDNode>(Val: operand.get())) |
346 | if (!tbaaMapping.contains(Val: childNode)) { |
347 | workList.push_back(Elt: childNode); |
348 | anyChildNotConverted = true; |
349 | } |
350 | |
351 | if (anyChildNotConverted) { |
352 | // If this is the second time we failed to convert an element in the |
353 | // worklist it must be because a child is dependent on it being converted |
354 | // and we have a cycle in the graph. Cycles are not allowed in TBAA |
355 | // graphs. |
356 | if (!seen.insert(V: current).second) |
357 | return emitError(loc) << "has cycle in TBAA graph: " |
358 | << diagMD(node: current, module: llvmModule.get()); |
359 | |
360 | continue; |
361 | } |
362 | |
363 | // Otherwise simply import the current node. |
364 | workList.pop_back(); |
365 | |
366 | FailureOr<std::optional<StringRef>> rootNodeIdentity = |
367 | getIdentityIfRootNode(current); |
368 | if (succeeded(result: rootNodeIdentity)) { |
369 | StringAttr stringAttr = *rootNodeIdentity |
370 | ? builder.getStringAttr(**rootNodeIdentity) |
371 | : nullptr; |
372 | // The root nodes do not have operands, so we can create |
373 | // the TBAARootAttr on the first walk. |
374 | tbaaMapping.insert({current, builder.getAttr<TBAARootAttr>(stringAttr)}); |
375 | continue; |
376 | } |
377 | |
378 | StringRef identity; |
379 | SmallVector<TBAAMemberAttr> members; |
380 | if (std::optional<bool> isValid = |
381 | isTypeDescriptorNode(current, &identity, &members)) { |
382 | assert(isValid.value() && "type descriptor node must be valid" ); |
383 | |
384 | tbaaMapping.insert({current, builder.getAttr<TBAATypeDescriptorAttr>( |
385 | identity, members)}); |
386 | continue; |
387 | } |
388 | |
389 | TBAATypeDescriptorAttr baseAttr, accessAttr; |
390 | int64_t offset; |
391 | bool isConstant; |
392 | if (std::optional<bool> isValid = |
393 | isTagNode(current, &baseAttr, &accessAttr, &offset, &isConstant)) { |
394 | assert(isValid.value() && "access tag node must be valid" ); |
395 | tbaaMapping.insert( |
396 | {current, builder.getAttr<TBAATagAttr>(baseAttr, accessAttr, offset, |
397 | isConstant)}); |
398 | continue; |
399 | } |
400 | |
401 | return emitError(loc) << "unsupported TBAA node format: " |
402 | << diagMD(node: current, module: llvmModule.get()); |
403 | } |
404 | return success(); |
405 | } |
406 | |
407 | LogicalResult |
408 | ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) { |
409 | Location loc = mlirModule.getLoc(); |
410 | if (failed(result: loopAnnotationImporter->translateAccessGroup(node, loc))) |
411 | return emitError(loc) << "unsupported access group node: " |
412 | << diagMD(node, module: llvmModule.get()); |
413 | return success(); |
414 | } |
415 | |
416 | LogicalResult |
417 | ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) { |
418 | Location loc = mlirModule.getLoc(); |
419 | // Helper that verifies the node has a self reference operand. |
420 | auto verifySelfRef = [](const llvm::MDNode *node) { |
421 | return node->getNumOperands() != 0 && |
422 | node == dyn_cast<llvm::MDNode>(Val: node->getOperand(I: 0)); |
423 | }; |
424 | // Helper that verifies the given operand is a string or does not exist. |
425 | auto verifyDescription = [](const llvm::MDNode *node, unsigned idx) { |
426 | return idx >= node->getNumOperands() || |
427 | isa<llvm::MDString>(Val: node->getOperand(I: idx)); |
428 | }; |
429 | // Helper that creates an alias scope domain attribute. |
430 | auto createAliasScopeDomainOp = [&](const llvm::MDNode *aliasDomain) { |
431 | StringAttr description = nullptr; |
432 | if (aliasDomain->getNumOperands() >= 2) |
433 | if (auto *operand = dyn_cast<llvm::MDString>(Val: aliasDomain->getOperand(I: 1))) |
434 | description = builder.getStringAttr(operand->getString()); |
435 | return builder.getAttr<AliasScopeDomainAttr>( |
436 | DistinctAttr::create(builder.getUnitAttr()), description); |
437 | }; |
438 | |
439 | // Collect the alias scopes and domains to translate them. |
440 | for (const llvm::MDOperand &operand : node->operands()) { |
441 | if (const auto *scope = dyn_cast<llvm::MDNode>(Val: operand)) { |
442 | llvm::AliasScopeNode aliasScope(scope); |
443 | const llvm::MDNode *domain = aliasScope.getDomain(); |
444 | |
445 | // Verify the scope node points to valid scope metadata which includes |
446 | // verifying its domain. Perform the verification before looking it up in |
447 | // the alias scope mapping since it could have been inserted as a domain |
448 | // node before. |
449 | if (!verifySelfRef(scope) || !domain || !verifyDescription(scope, 2)) |
450 | return emitError(loc) << "unsupported alias scope node: " |
451 | << diagMD(node: scope, module: llvmModule.get()); |
452 | if (!verifySelfRef(domain) || !verifyDescription(domain, 1)) |
453 | return emitError(loc) << "unsupported alias domain node: " |
454 | << diagMD(node: domain, module: llvmModule.get()); |
455 | |
456 | if (aliasScopeMapping.contains(Val: scope)) |
457 | continue; |
458 | |
459 | // Convert the domain metadata node if it has not been translated before. |
460 | auto it = aliasScopeMapping.find(Val: aliasScope.getDomain()); |
461 | if (it == aliasScopeMapping.end()) { |
462 | auto aliasScopeDomainOp = createAliasScopeDomainOp(domain); |
463 | it = aliasScopeMapping.try_emplace(domain, aliasScopeDomainOp).first; |
464 | } |
465 | |
466 | // Convert the scope metadata node if it has not been converted before. |
467 | StringAttr description = nullptr; |
468 | if (!aliasScope.getName().empty()) |
469 | description = builder.getStringAttr(aliasScope.getName()); |
470 | auto aliasScopeOp = builder.getAttr<AliasScopeAttr>( |
471 | DistinctAttr::create(builder.getUnitAttr()), |
472 | cast<AliasScopeDomainAttr>(it->second), description); |
473 | aliasScopeMapping.try_emplace(aliasScope.getNode(), aliasScopeOp); |
474 | } |
475 | } |
476 | return success(); |
477 | } |
478 | |
479 | FailureOr<SmallVector<AliasScopeAttr>> |
480 | ModuleImport::lookupAliasScopeAttrs(const llvm::MDNode *node) const { |
481 | SmallVector<AliasScopeAttr> aliasScopes; |
482 | aliasScopes.reserve(node->getNumOperands()); |
483 | for (const llvm::MDOperand &operand : node->operands()) { |
484 | auto *node = cast<llvm::MDNode>(Val: operand.get()); |
485 | aliasScopes.push_back( |
486 | dyn_cast_or_null<AliasScopeAttr>(aliasScopeMapping.lookup(node))); |
487 | } |
488 | // Return failure if one of the alias scope lookups failed. |
489 | if (llvm::is_contained(aliasScopes, nullptr)) |
490 | return failure(); |
491 | return aliasScopes; |
492 | } |
493 | |
494 | void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) { |
495 | debugIntrinsics.insert(X: intrinsic); |
496 | } |
497 | |
498 | LogicalResult ModuleImport::convertLinkerOptionsMetadata() { |
499 | for (const llvm::NamedMDNode &named : llvmModule->named_metadata()) { |
500 | if (named.getName() != "llvm.linker.options" ) |
501 | continue; |
502 | // llvm.linker.options operands are lists of strings. |
503 | for (const llvm::MDNode *md : named.operands()) { |
504 | SmallVector<StringRef> options; |
505 | options.reserve(N: md->getNumOperands()); |
506 | for (const llvm::MDOperand &option : md->operands()) |
507 | options.push_back(Elt: cast<llvm::MDString>(Val: option)->getString()); |
508 | builder.create<LLVM::LinkerOptionsOp>(mlirModule.getLoc(), |
509 | builder.getStrArrayAttr(options)); |
510 | } |
511 | } |
512 | return success(); |
513 | } |
514 | |
515 | LogicalResult ModuleImport::convertMetadata() { |
516 | OpBuilder::InsertionGuard guard(builder); |
517 | builder.setInsertionPointToEnd(mlirModule.getBody()); |
518 | for (const llvm::Function &func : llvmModule->functions()) { |
519 | for (const llvm::Instruction &inst : llvm::instructions(F: func)) { |
520 | // Convert access group metadata nodes. |
521 | if (llvm::MDNode *node = |
522 | inst.getMetadata(KindID: llvm::LLVMContext::MD_access_group)) |
523 | if (failed(result: processAccessGroupMetadata(node))) |
524 | return failure(); |
525 | |
526 | // Convert alias analysis metadata nodes. |
527 | llvm::AAMDNodes aliasAnalysisNodes = inst.getAAMetadata(); |
528 | if (!aliasAnalysisNodes) |
529 | continue; |
530 | if (aliasAnalysisNodes.TBAA) |
531 | if (failed(result: processTBAAMetadata(node: aliasAnalysisNodes.TBAA))) |
532 | return failure(); |
533 | if (aliasAnalysisNodes.Scope) |
534 | if (failed(result: processAliasScopeMetadata(node: aliasAnalysisNodes.Scope))) |
535 | return failure(); |
536 | if (aliasAnalysisNodes.NoAlias) |
537 | if (failed(result: processAliasScopeMetadata(node: aliasAnalysisNodes.NoAlias))) |
538 | return failure(); |
539 | } |
540 | } |
541 | if (failed(result: convertLinkerOptionsMetadata())) |
542 | return failure(); |
543 | return success(); |
544 | } |
545 | |
546 | void ModuleImport::processComdat(const llvm::Comdat *comdat) { |
547 | if (comdatMapping.contains(Val: comdat)) |
548 | return; |
549 | |
550 | ComdatOp comdatOp = getGlobalComdatOp(); |
551 | OpBuilder::InsertionGuard guard(builder); |
552 | builder.setInsertionPointToEnd(&comdatOp.getBody().back()); |
553 | auto selectorOp = builder.create<ComdatSelectorOp>( |
554 | mlirModule.getLoc(), comdat->getName(), |
555 | convertComdatFromLLVM(comdat->getSelectionKind())); |
556 | auto symbolRef = |
557 | SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(), |
558 | FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())); |
559 | comdatMapping.try_emplace(comdat, symbolRef); |
560 | } |
561 | |
562 | LogicalResult ModuleImport::convertComdats() { |
563 | for (llvm::GlobalVariable &globalVar : llvmModule->globals()) |
564 | if (globalVar.hasComdat()) |
565 | processComdat(comdat: globalVar.getComdat()); |
566 | for (llvm::Function &func : llvmModule->functions()) |
567 | if (func.hasComdat()) |
568 | processComdat(comdat: func.getComdat()); |
569 | return success(); |
570 | } |
571 | |
572 | LogicalResult ModuleImport::convertGlobals() { |
573 | for (llvm::GlobalVariable &globalVar : llvmModule->globals()) { |
574 | if (globalVar.getName() == getGlobalCtorsVarName() || |
575 | globalVar.getName() == getGlobalDtorsVarName()) { |
576 | if (failed(result: convertGlobalCtorsAndDtors(globalVar: &globalVar))) { |
577 | return emitError(UnknownLoc::get(context)) |
578 | << "unhandled global variable: " << diag(globalVar); |
579 | } |
580 | continue; |
581 | } |
582 | if (failed(result: convertGlobal(globalVar: &globalVar))) { |
583 | return emitError(UnknownLoc::get(context)) |
584 | << "unhandled global variable: " << diag(globalVar); |
585 | } |
586 | } |
587 | return success(); |
588 | } |
589 | |
590 | LogicalResult ModuleImport::convertDataLayout() { |
591 | Location loc = mlirModule.getLoc(); |
592 | DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout()); |
593 | if (!dataLayoutImporter.getDataLayout()) |
594 | return emitError(loc, message: "cannot translate data layout: " ) |
595 | << dataLayoutImporter.getLastToken(); |
596 | |
597 | for (StringRef token : dataLayoutImporter.getUnhandledTokens()) |
598 | emitWarning(loc, message: "unhandled data layout token: " ) << token; |
599 | |
600 | mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName, |
601 | dataLayoutImporter.getDataLayout()); |
602 | return success(); |
603 | } |
604 | |
605 | LogicalResult ModuleImport::convertFunctions() { |
606 | for (llvm::Function &func : llvmModule->functions()) |
607 | if (failed(result: processFunction(func: &func))) |
608 | return failure(); |
609 | return success(); |
610 | } |
611 | |
612 | void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst, |
613 | Operation *op) { |
614 | SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata; |
615 | inst->getAllMetadataOtherThanDebugLoc(MDs&: allMetadata); |
616 | for (auto &[kind, node] : allMetadata) { |
617 | if (!iface.isConvertibleMetadata(kind)) |
618 | continue; |
619 | if (failed(result: iface.setMetadataAttrs(builder, kind, node, op, moduleImport&: *this))) { |
620 | if (emitExpensiveWarnings) { |
621 | Location loc = debugImporter->translateLoc(loc: inst->getDebugLoc()); |
622 | emitWarning(loc) << "unhandled metadata: " |
623 | << diagMD(node, module: llvmModule.get()) << " on " |
624 | << diag(value: *inst); |
625 | } |
626 | } |
627 | } |
628 | } |
629 | |
630 | void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, |
631 | Operation *op) const { |
632 | auto iface = cast<IntegerOverflowFlagsInterface>(op); |
633 | |
634 | IntegerOverflowFlags value = {}; |
635 | value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap()); |
636 | value = |
637 | bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap()); |
638 | |
639 | iface.setOverflowFlags(value); |
640 | } |
641 | |
642 | void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, |
643 | Operation *op) const { |
644 | auto iface = cast<FastmathFlagsInterface>(op); |
645 | |
646 | // Even if the imported operation implements the fastmath interface, the |
647 | // original instruction may not have fastmath flags set. Exit if an |
648 | // instruction, such as a non floating-point function call, does not have |
649 | // fastmath flags. |
650 | if (!isa<llvm::FPMathOperator>(Val: inst)) |
651 | return; |
652 | llvm::FastMathFlags flags = inst->getFastMathFlags(); |
653 | |
654 | // Set the fastmath bits flag-by-flag. |
655 | FastmathFlags value = {}; |
656 | value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs()); |
657 | value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs()); |
658 | value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros()); |
659 | value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal()); |
660 | value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract()); |
661 | value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc()); |
662 | value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc()); |
663 | FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value); |
664 | iface->setAttr(iface.getFastmathAttrName(), attr); |
665 | } |
666 | |
667 | /// Returns if `type` is a scalar integer or floating-point type. |
668 | static bool isScalarType(Type type) { |
669 | return isa<IntegerType, FloatType>(Val: type); |
670 | } |
671 | |
672 | /// Returns `type` if it is a builtin integer or floating-point vector type that |
673 | /// can be used to create an attribute or nullptr otherwise. If provided, |
674 | /// `arrayShape` is added to the shape of the vector to create an attribute that |
675 | /// matches an array of vectors. |
676 | static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) { |
677 | if (!LLVM::isCompatibleVectorType(type)) |
678 | return {}; |
679 | |
680 | llvm::ElementCount numElements = LLVM::getVectorNumElements(type); |
681 | if (numElements.isScalable()) { |
682 | emitError(UnknownLoc::get(type.getContext())) |
683 | << "scalable vectors not supported" ; |
684 | return {}; |
685 | } |
686 | |
687 | // An LLVM dialect vector can only contain scalars. |
688 | Type elementType = LLVM::getVectorElementType(type); |
689 | if (!isScalarType(type: elementType)) |
690 | return {}; |
691 | |
692 | SmallVector<int64_t> shape(arrayShape.begin(), arrayShape.end()); |
693 | shape.push_back(Elt: numElements.getKnownMinValue()); |
694 | return VectorType::get(shape, elementType); |
695 | } |
696 | |
697 | Type ModuleImport::getBuiltinTypeForAttr(Type type) { |
698 | if (!type) |
699 | return {}; |
700 | |
701 | // Return builtin integer and floating-point types as is. |
702 | if (isScalarType(type)) |
703 | return type; |
704 | |
705 | // Return builtin vectors of integer and floating-point types as is. |
706 | if (Type vectorType = getVectorTypeForAttr(type)) |
707 | return vectorType; |
708 | |
709 | // Multi-dimensional array types are converted to tensors or vectors, |
710 | // depending on the innermost type being a scalar or a vector. |
711 | SmallVector<int64_t> arrayShape; |
712 | while (auto arrayType = dyn_cast<LLVMArrayType>(type)) { |
713 | arrayShape.push_back(Elt: arrayType.getNumElements()); |
714 | type = arrayType.getElementType(); |
715 | } |
716 | if (isScalarType(type)) |
717 | return RankedTensorType::get(arrayShape, type); |
718 | return getVectorTypeForAttr(type, arrayShape); |
719 | } |
720 | |
721 | /// Returns an integer or float attribute for the provided scalar constant |
722 | /// `constScalar` or nullptr if the conversion fails. |
723 | static TypedAttr getScalarConstantAsAttr(OpBuilder &builder, |
724 | llvm::Constant *constScalar) { |
725 | MLIRContext *context = builder.getContext(); |
726 | |
727 | // Convert scalar intergers. |
728 | if (auto *constInt = dyn_cast<llvm::ConstantInt>(Val: constScalar)) { |
729 | return builder.getIntegerAttr( |
730 | IntegerType::get(context, constInt->getBitWidth()), |
731 | constInt->getValue()); |
732 | } |
733 | |
734 | // Convert scalar floats. |
735 | if (auto *constFloat = dyn_cast<llvm::ConstantFP>(Val: constScalar)) { |
736 | llvm::Type *type = constFloat->getType(); |
737 | FloatType floatType = |
738 | type->isBFloatTy() |
739 | ? FloatType::getBF16(ctx: context) |
740 | : LLVM::detail::getFloatType(context, width: type->getScalarSizeInBits()); |
741 | if (!floatType) { |
742 | emitError(UnknownLoc::get(builder.getContext())) |
743 | << "unexpected floating-point type" ; |
744 | return {}; |
745 | } |
746 | return builder.getFloatAttr(floatType, constFloat->getValueAPF()); |
747 | } |
748 | return {}; |
749 | } |
750 | |
751 | /// Returns an integer or float attribute array for the provided constant |
752 | /// sequence `constSequence` or nullptr if the conversion fails. |
753 | static SmallVector<Attribute> |
754 | getSequenceConstantAsAttrs(OpBuilder &builder, |
755 | llvm::ConstantDataSequential *constSequence) { |
756 | SmallVector<Attribute> elementAttrs; |
757 | elementAttrs.reserve(N: constSequence->getNumElements()); |
758 | for (auto idx : llvm::seq<int64_t>(Begin: 0, End: constSequence->getNumElements())) { |
759 | llvm::Constant *constElement = constSequence->getElementAsConstant(i: idx); |
760 | elementAttrs.push_back(getScalarConstantAsAttr(builder, constElement)); |
761 | } |
762 | return elementAttrs; |
763 | } |
764 | |
765 | Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) { |
766 | // Convert scalar constants. |
767 | if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant)) |
768 | return scalarAttr; |
769 | |
770 | // Convert function references. |
771 | if (auto *func = dyn_cast<llvm::Function>(constant)) |
772 | return SymbolRefAttr::get(builder.getContext(), func->getName()); |
773 | |
774 | // Returns the static shape of the provided type if possible. |
775 | auto getConstantShape = [&](llvm::Type *type) { |
776 | return llvm::dyn_cast_if_present<ShapedType>( |
777 | getBuiltinTypeForAttr(convertType(type))); |
778 | }; |
779 | |
780 | // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte |
781 | // integer or half/bfloat/float/double values. |
782 | if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(Val: constant)) { |
783 | if (constArray->isString()) |
784 | return builder.getStringAttr(constArray->getAsString()); |
785 | auto shape = getConstantShape(constArray->getType()); |
786 | if (!shape) |
787 | return {}; |
788 | // Convert splat constants to splat elements attributes. |
789 | auto *constVector = dyn_cast<llvm::ConstantDataVector>(Val: constant); |
790 | if (constVector && constVector->isSplat()) { |
791 | // A vector is guaranteed to have at least size one. |
792 | Attribute splatAttr = getScalarConstantAsAttr( |
793 | builder, constVector->getElementAsConstant(0)); |
794 | return SplatElementsAttr::get(shape, splatAttr); |
795 | } |
796 | // Convert non-splat constants to dense elements attributes. |
797 | SmallVector<Attribute> elementAttrs = |
798 | getSequenceConstantAsAttrs(builder, constSequence: constArray); |
799 | return DenseElementsAttr::get(shape, elementAttrs); |
800 | } |
801 | |
802 | // Convert multi-dimensional constant aggregates that store all kinds of |
803 | // integer and floating-point types. |
804 | if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(Val: constant)) { |
805 | auto shape = getConstantShape(constAggregate->getType()); |
806 | if (!shape) |
807 | return {}; |
808 | // Collect the aggregate elements in depths first order. |
809 | SmallVector<Attribute> elementAttrs; |
810 | SmallVector<llvm::Constant *> workList = {constAggregate}; |
811 | while (!workList.empty()) { |
812 | llvm::Constant *current = workList.pop_back_val(); |
813 | // Append any nested aggregates in reverse order to ensure the head |
814 | // element of the nested aggregates is at the back of the work list. |
815 | if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(Val: current)) { |
816 | for (auto idx : |
817 | reverse(C: llvm::seq<int64_t>(Begin: 0, End: constAggregate->getNumOperands()))) |
818 | workList.push_back(Elt: constAggregate->getAggregateElement(Elt: idx)); |
819 | continue; |
820 | } |
821 | // Append the elements of nested constant arrays or vectors that store |
822 | // 1/2/4/8-byte integer or half/bfloat/float/double values. |
823 | if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(Val: current)) { |
824 | SmallVector<Attribute> attrs = |
825 | getSequenceConstantAsAttrs(builder, constSequence: constArray); |
826 | elementAttrs.append(in_start: attrs.begin(), in_end: attrs.end()); |
827 | continue; |
828 | } |
829 | // Append nested scalar constants that store all kinds of integer and |
830 | // floating-point types. |
831 | if (Attribute scalarAttr = getScalarConstantAsAttr(builder, current)) { |
832 | elementAttrs.push_back(Elt: scalarAttr); |
833 | continue; |
834 | } |
835 | // Bail if the aggregate contains a unsupported constant type such as a |
836 | // constant expression. |
837 | return {}; |
838 | } |
839 | return DenseElementsAttr::get(shape, elementAttrs); |
840 | } |
841 | |
842 | // Convert zero aggregates. |
843 | if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(Val: constant)) { |
844 | auto shape = llvm::dyn_cast_if_present<ShapedType>( |
845 | getBuiltinTypeForAttr(convertType(constZero->getType()))); |
846 | if (!shape) |
847 | return {}; |
848 | // Convert zero aggregates with a static shape to splat elements attributes. |
849 | Attribute splatAttr = builder.getZeroAttr(type: shape.getElementType()); |
850 | assert(splatAttr && "expected non-null zero attribute for scalar types" ); |
851 | return SplatElementsAttr::get(shape, splatAttr); |
852 | } |
853 | return {}; |
854 | } |
855 | |
856 | LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { |
857 | // Insert the global after the last one or at the start of the module. |
858 | OpBuilder::InsertionGuard guard(builder); |
859 | if (!globalInsertionOp) |
860 | builder.setInsertionPointToStart(mlirModule.getBody()); |
861 | else |
862 | builder.setInsertionPointAfter(globalInsertionOp); |
863 | |
864 | Attribute valueAttr; |
865 | if (globalVar->hasInitializer()) |
866 | valueAttr = getConstantAsAttr(constant: globalVar->getInitializer()); |
867 | Type type = convertType(type: globalVar->getValueType()); |
868 | |
869 | uint64_t alignment = 0; |
870 | llvm::MaybeAlign maybeAlign = globalVar->getAlign(); |
871 | if (maybeAlign.has_value()) { |
872 | llvm::Align align = *maybeAlign; |
873 | alignment = align.value(); |
874 | } |
875 | |
876 | // Get the global expression associated with this global variable and convert |
877 | // it. |
878 | DIGlobalVariableExpressionAttr globalExpressionAttr; |
879 | SmallVector<llvm::DIGlobalVariableExpression *> globalExpressions; |
880 | globalVar->getDebugInfo(GVs&: globalExpressions); |
881 | |
882 | // There should only be a single global expression. |
883 | if (!globalExpressions.empty()) |
884 | globalExpressionAttr = |
885 | debugImporter->translateGlobalVariableExpression(globalExpressions[0]); |
886 | |
887 | GlobalOp globalOp = builder.create<GlobalOp>( |
888 | mlirModule.getLoc(), type, globalVar->isConstant(), |
889 | convertLinkageFromLLVM(globalVar->getLinkage()), globalVar->getName(), |
890 | valueAttr, alignment, /*addr_space=*/globalVar->getAddressSpace(), |
891 | /*dso_local=*/globalVar->isDSOLocal(), |
892 | /*thread_local=*/globalVar->isThreadLocal(), /*comdat=*/SymbolRefAttr(), |
893 | /*attrs=*/ArrayRef<NamedAttribute>(), /*dbgExpr=*/globalExpressionAttr); |
894 | globalInsertionOp = globalOp; |
895 | |
896 | if (globalVar->hasInitializer() && !valueAttr) { |
897 | clearRegionState(); |
898 | Block *block = builder.createBlock(&globalOp.getInitializerRegion()); |
899 | setConstantInsertionPointToStart(block); |
900 | FailureOr<Value> initializer = |
901 | convertConstantExpr(constant: globalVar->getInitializer()); |
902 | if (failed(result: initializer)) |
903 | return failure(); |
904 | builder.create<ReturnOp>(globalOp.getLoc(), *initializer); |
905 | } |
906 | if (globalVar->hasAtLeastLocalUnnamedAddr()) { |
907 | globalOp.setUnnamedAddr( |
908 | convertUnnamedAddrFromLLVM(globalVar->getUnnamedAddr())); |
909 | } |
910 | if (globalVar->hasSection()) |
911 | globalOp.setSection(globalVar->getSection()); |
912 | globalOp.setVisibility_( |
913 | convertVisibilityFromLLVM(globalVar->getVisibility())); |
914 | |
915 | if (globalVar->hasComdat()) |
916 | globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat())); |
917 | |
918 | return success(); |
919 | } |
920 | |
921 | LogicalResult |
922 | ModuleImport::convertGlobalCtorsAndDtors(llvm::GlobalVariable *globalVar) { |
923 | if (!globalVar->hasInitializer() || !globalVar->hasAppendingLinkage()) |
924 | return failure(); |
925 | auto *initializer = |
926 | dyn_cast<llvm::ConstantArray>(Val: globalVar->getInitializer()); |
927 | if (!initializer) |
928 | return failure(); |
929 | |
930 | SmallVector<Attribute> funcs; |
931 | SmallVector<int32_t> priorities; |
932 | for (llvm::Value *operand : initializer->operands()) { |
933 | auto *aggregate = dyn_cast<llvm::ConstantAggregate>(Val: operand); |
934 | if (!aggregate || aggregate->getNumOperands() != 3) |
935 | return failure(); |
936 | |
937 | auto *priority = dyn_cast<llvm::ConstantInt>(Val: aggregate->getOperand(i_nocapture: 0)); |
938 | auto *func = dyn_cast<llvm::Function>(Val: aggregate->getOperand(i_nocapture: 1)); |
939 | auto *data = dyn_cast<llvm::Constant>(Val: aggregate->getOperand(i_nocapture: 2)); |
940 | if (!priority || !func || !data) |
941 | return failure(); |
942 | |
943 | // GlobalCtorsOps and GlobalDtorsOps do not support non-null data fields. |
944 | if (!data->isNullValue()) |
945 | return failure(); |
946 | |
947 | funcs.push_back(FlatSymbolRefAttr::get(ctx: context, value: func->getName())); |
948 | priorities.push_back(Elt: priority->getValue().getZExtValue()); |
949 | } |
950 | |
951 | OpBuilder::InsertionGuard guard(builder); |
952 | if (!globalInsertionOp) |
953 | builder.setInsertionPointToStart(mlirModule.getBody()); |
954 | else |
955 | builder.setInsertionPointAfter(globalInsertionOp); |
956 | |
957 | if (globalVar->getName() == getGlobalCtorsVarName()) { |
958 | globalInsertionOp = builder.create<LLVM::GlobalCtorsOp>( |
959 | mlirModule.getLoc(), builder.getArrayAttr(funcs), |
960 | builder.getI32ArrayAttr(priorities)); |
961 | return success(); |
962 | } |
963 | globalInsertionOp = builder.create<LLVM::GlobalDtorsOp>( |
964 | mlirModule.getLoc(), builder.getArrayAttr(funcs), |
965 | builder.getI32ArrayAttr(priorities)); |
966 | return success(); |
967 | } |
968 | |
969 | SetVector<llvm::Constant *> |
970 | ModuleImport::getConstantsToConvert(llvm::Constant *constant) { |
971 | // Return the empty set if the constant has been translated before. |
972 | if (valueMapping.contains(Val: constant)) |
973 | return {}; |
974 | |
975 | // Traverse the constants in post-order and stop the traversal if a constant |
976 | // already has a `valueMapping` from an earlier constant translation or if the |
977 | // constant is traversed a second time. |
978 | SetVector<llvm::Constant *> orderedSet; |
979 | SetVector<llvm::Constant *> workList; |
980 | DenseMap<llvm::Constant *, SmallVector<llvm::Constant *>> adjacencyLists; |
981 | workList.insert(X: constant); |
982 | while (!workList.empty()) { |
983 | llvm::Constant *current = workList.back(); |
984 | // Collect all dependencies of the current constant and add them to the |
985 | // adjacency list if none has been computed before. |
986 | auto adjacencyIt = adjacencyLists.find(Val: current); |
987 | if (adjacencyIt == adjacencyLists.end()) { |
988 | adjacencyIt = adjacencyLists.try_emplace(Key: current).first; |
989 | // Add all constant operands to the adjacency list and skip any other |
990 | // values such as basic block addresses. |
991 | for (llvm::Value *operand : current->operands()) |
992 | if (auto *constDependency = dyn_cast<llvm::Constant>(Val: operand)) |
993 | adjacencyIt->getSecond().push_back(Elt: constDependency); |
994 | // Use the getElementValue method to add the dependencies of zero |
995 | // initialized aggregate constants since they do not take any operands. |
996 | if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(Val: current)) { |
997 | unsigned numElements = constAgg->getElementCount().getFixedValue(); |
998 | for (unsigned i = 0, e = numElements; i != e; ++i) |
999 | adjacencyIt->getSecond().push_back(Elt: constAgg->getElementValue(Idx: i)); |
1000 | } |
1001 | } |
1002 | // Add the current constant to the `orderedSet` of the traversed nodes if |
1003 | // all its dependencies have been traversed before. Additionally, remove the |
1004 | // constant from the `workList` and continue the traversal. |
1005 | if (adjacencyIt->getSecond().empty()) { |
1006 | orderedSet.insert(X: current); |
1007 | workList.pop_back(); |
1008 | continue; |
1009 | } |
1010 | // Add the next dependency from the adjacency list to the `workList` and |
1011 | // continue the traversal. Remove the dependency from the adjacency list to |
1012 | // mark that it has been processed. Only enqueue the dependency if it has no |
1013 | // `valueMapping` from an earlier translation and if it has not been |
1014 | // enqueued before. |
1015 | llvm::Constant *dependency = adjacencyIt->getSecond().pop_back_val(); |
1016 | if (valueMapping.contains(Val: dependency) || workList.contains(key: dependency) || |
1017 | orderedSet.contains(key: dependency)) |
1018 | continue; |
1019 | workList.insert(X: dependency); |
1020 | } |
1021 | |
1022 | return orderedSet; |
1023 | } |
1024 | |
1025 | FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) { |
1026 | Location loc = UnknownLoc::get(context); |
1027 | |
1028 | // Convert constants that can be represented as attributes. |
1029 | if (Attribute attr = getConstantAsAttr(constant)) { |
1030 | Type type = convertType(type: constant->getType()); |
1031 | if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) { |
1032 | return builder.create<AddressOfOp>(loc, type, symbolRef.getValue()) |
1033 | .getResult(); |
1034 | } |
1035 | return builder.create<ConstantOp>(loc, type, attr).getResult(); |
1036 | } |
1037 | |
1038 | // Convert null pointer constants. |
1039 | if (auto *nullPtr = dyn_cast<llvm::ConstantPointerNull>(Val: constant)) { |
1040 | Type type = convertType(type: nullPtr->getType()); |
1041 | return builder.create<ZeroOp>(loc, type).getResult(); |
1042 | } |
1043 | |
1044 | // Convert none token constants. |
1045 | if (isa<llvm::ConstantTokenNone>(Val: constant)) { |
1046 | return builder.create<NoneTokenOp>(loc).getResult(); |
1047 | } |
1048 | |
1049 | // Convert poison. |
1050 | if (auto *poisonVal = dyn_cast<llvm::PoisonValue>(Val: constant)) { |
1051 | Type type = convertType(type: poisonVal->getType()); |
1052 | return builder.create<PoisonOp>(loc, type).getResult(); |
1053 | } |
1054 | |
1055 | // Convert undef. |
1056 | if (auto *undefVal = dyn_cast<llvm::UndefValue>(Val: constant)) { |
1057 | Type type = convertType(type: undefVal->getType()); |
1058 | return builder.create<UndefOp>(loc, type).getResult(); |
1059 | } |
1060 | |
1061 | // Convert global variable accesses. |
1062 | if (auto *globalVar = dyn_cast<llvm::GlobalVariable>(Val: constant)) { |
1063 | Type type = convertType(type: globalVar->getType()); |
1064 | auto symbolRef = FlatSymbolRefAttr::get(ctx: context, value: globalVar->getName()); |
1065 | return builder.create<AddressOfOp>(loc, type, symbolRef).getResult(); |
1066 | } |
1067 | |
1068 | // Convert constant expressions. |
1069 | if (auto *constExpr = dyn_cast<llvm::ConstantExpr>(Val: constant)) { |
1070 | // Convert the constant expression to a temporary LLVM instruction and |
1071 | // translate it using the `processInstruction` method. Delete the |
1072 | // instruction after the translation and remove it from `valueMapping`, |
1073 | // since later calls to `getAsInstruction` may return the same address |
1074 | // resulting in a conflicting `valueMapping` entry. |
1075 | llvm::Instruction *inst = constExpr->getAsInstruction(); |
1076 | auto guard = llvm::make_scope_exit(F: [&]() { |
1077 | assert(!noResultOpMapping.contains(inst) && |
1078 | "expected constant expression to return a result" ); |
1079 | valueMapping.erase(Val: inst); |
1080 | inst->deleteValue(); |
1081 | }); |
1082 | // Note: `processInstruction` does not call `convertConstant` recursively |
1083 | // since all constant dependencies have been converted before. |
1084 | assert(llvm::all_of(inst->operands(), [&](llvm::Value *value) { |
1085 | return valueMapping.contains(value); |
1086 | })); |
1087 | if (failed(result: processInstruction(inst))) |
1088 | return failure(); |
1089 | return lookupValue(value: inst); |
1090 | } |
1091 | |
1092 | // Convert aggregate constants. |
1093 | if (isa<llvm::ConstantAggregate>(Val: constant) || |
1094 | isa<llvm::ConstantAggregateZero>(Val: constant)) { |
1095 | // Lookup the aggregate elements that have been converted before. |
1096 | SmallVector<Value> elementValues; |
1097 | if (auto *constAgg = dyn_cast<llvm::ConstantAggregate>(Val: constant)) { |
1098 | elementValues.reserve(N: constAgg->getNumOperands()); |
1099 | for (llvm::Value *operand : constAgg->operands()) |
1100 | elementValues.push_back(Elt: lookupValue(value: operand)); |
1101 | } |
1102 | if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(Val: constant)) { |
1103 | unsigned numElements = constAgg->getElementCount().getFixedValue(); |
1104 | elementValues.reserve(N: numElements); |
1105 | for (unsigned i = 0, e = numElements; i != e; ++i) |
1106 | elementValues.push_back(Elt: lookupValue(value: constAgg->getElementValue(Idx: i))); |
1107 | } |
1108 | assert(llvm::count(elementValues, nullptr) == 0 && |
1109 | "expected all elements have been converted before" ); |
1110 | |
1111 | // Generate an UndefOp as root value and insert the aggregate elements. |
1112 | Type rootType = convertType(type: constant->getType()); |
1113 | bool isArrayOrStruct = isa<LLVMArrayType, LLVMStructType>(rootType); |
1114 | assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) && |
1115 | "unrecognized aggregate type" ); |
1116 | Value root = builder.create<UndefOp>(loc, rootType); |
1117 | for (const auto &it : llvm::enumerate(First&: elementValues)) { |
1118 | if (isArrayOrStruct) { |
1119 | root = builder.create<InsertValueOp>(loc, root, it.value(), it.index()); |
1120 | } else { |
1121 | Attribute indexAttr = builder.getI32IntegerAttr(it.index()); |
1122 | Value indexValue = |
1123 | builder.create<ConstantOp>(loc, builder.getI32Type(), indexAttr); |
1124 | root = builder.create<InsertElementOp>(loc, rootType, root, it.value(), |
1125 | indexValue); |
1126 | } |
1127 | } |
1128 | return root; |
1129 | } |
1130 | |
1131 | if (auto *constTargetNone = dyn_cast<llvm::ConstantTargetNone>(Val: constant)) { |
1132 | LLVMTargetExtType targetExtType = |
1133 | cast<LLVMTargetExtType>(convertType(constTargetNone->getType())); |
1134 | assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) && |
1135 | "target extension type does not support zero-initialization" ); |
1136 | // Create llvm.mlir.zero operation to represent zero-initialization of |
1137 | // target extension type. |
1138 | return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes(); |
1139 | } |
1140 | |
1141 | StringRef error = "" ; |
1142 | if (isa<llvm::BlockAddress>(Val: constant)) |
1143 | error = " since blockaddress(...) is unsupported" ; |
1144 | |
1145 | return emitError(loc) << "unhandled constant: " << diag(value: *constant) << error; |
1146 | } |
1147 | |
1148 | FailureOr<Value> ModuleImport::convertConstantExpr(llvm::Constant *constant) { |
1149 | // Only call the function for constants that have not been translated before |
1150 | // since it updates the constant insertion point assuming the converted |
1151 | // constant has been introduced at the end of the constant section. |
1152 | assert(!valueMapping.contains(constant) && |
1153 | "expected constant has not been converted before" ); |
1154 | assert(constantInsertionBlock && |
1155 | "expected the constant insertion block to be non-null" ); |
1156 | |
1157 | // Insert the constant after the last one or at the start of the entry block. |
1158 | OpBuilder::InsertionGuard guard(builder); |
1159 | if (!constantInsertionOp) |
1160 | builder.setInsertionPointToStart(constantInsertionBlock); |
1161 | else |
1162 | builder.setInsertionPointAfter(constantInsertionOp); |
1163 | |
1164 | // Convert all constants of the expression and add them to `valueMapping`. |
1165 | SetVector<llvm::Constant *> constantsToConvert = |
1166 | getConstantsToConvert(constant); |
1167 | for (llvm::Constant *constantToConvert : constantsToConvert) { |
1168 | FailureOr<Value> converted = convertConstant(constant: constantToConvert); |
1169 | if (failed(result: converted)) |
1170 | return failure(); |
1171 | mapValue(llvm: constantToConvert, mlir: *converted); |
1172 | } |
1173 | |
1174 | // Update the constant insertion point and return the converted constant. |
1175 | Value result = lookupValue(value: constant); |
1176 | constantInsertionOp = result.getDefiningOp(); |
1177 | return result; |
1178 | } |
1179 | |
1180 | FailureOr<Value> ModuleImport::convertValue(llvm::Value *value) { |
1181 | assert(!isa<llvm::MetadataAsValue>(value) && |
1182 | "expected value to not be metadata" ); |
1183 | |
1184 | // Return the mapped value if it has been converted before. |
1185 | auto it = valueMapping.find(Val: value); |
1186 | if (it != valueMapping.end()) |
1187 | return it->getSecond(); |
1188 | |
1189 | // Convert constants such as immediate values that have no mapping yet. |
1190 | if (auto *constant = dyn_cast<llvm::Constant>(Val: value)) |
1191 | return convertConstantExpr(constant); |
1192 | |
1193 | Location loc = UnknownLoc::get(context); |
1194 | if (auto *inst = dyn_cast<llvm::Instruction>(Val: value)) |
1195 | loc = translateLoc(loc: inst->getDebugLoc()); |
1196 | return emitError(loc) << "unhandled value: " << diag(value: *value); |
1197 | } |
1198 | |
1199 | FailureOr<Value> ModuleImport::convertMetadataValue(llvm::Value *value) { |
1200 | // A value may be wrapped as metadata, for example, when passed to a debug |
1201 | // intrinsic. Unwrap these values before the conversion. |
1202 | auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(Val: value); |
1203 | if (!nodeAsVal) |
1204 | return failure(); |
1205 | auto *node = dyn_cast<llvm::ValueAsMetadata>(Val: nodeAsVal->getMetadata()); |
1206 | if (!node) |
1207 | return failure(); |
1208 | value = node->getValue(); |
1209 | |
1210 | // Return the mapped value if it has been converted before. |
1211 | auto it = valueMapping.find(Val: value); |
1212 | if (it != valueMapping.end()) |
1213 | return it->getSecond(); |
1214 | |
1215 | // Convert constants such as immediate values that have no mapping yet. |
1216 | if (auto *constant = dyn_cast<llvm::Constant>(Val: value)) |
1217 | return convertConstantExpr(constant); |
1218 | return failure(); |
1219 | } |
1220 | |
1221 | FailureOr<SmallVector<Value>> |
1222 | ModuleImport::convertValues(ArrayRef<llvm::Value *> values) { |
1223 | SmallVector<Value> remapped; |
1224 | remapped.reserve(N: values.size()); |
1225 | for (llvm::Value *value : values) { |
1226 | FailureOr<Value> converted = convertValue(value); |
1227 | if (failed(result: converted)) |
1228 | return failure(); |
1229 | remapped.push_back(Elt: *converted); |
1230 | } |
1231 | return remapped; |
1232 | } |
1233 | |
1234 | LogicalResult ModuleImport::convertIntrinsicArguments( |
1235 | ArrayRef<llvm::Value *> values, ArrayRef<unsigned> immArgPositions, |
1236 | ArrayRef<StringLiteral> immArgAttrNames, SmallVectorImpl<Value> &valuesOut, |
1237 | SmallVectorImpl<NamedAttribute> &attrsOut) { |
1238 | assert(immArgPositions.size() == immArgAttrNames.size() && |
1239 | "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal " |
1240 | "length" ); |
1241 | |
1242 | SmallVector<llvm::Value *> operands(values); |
1243 | for (auto [immArgPos, immArgName] : |
1244 | llvm::zip(t&: immArgPositions, u&: immArgAttrNames)) { |
1245 | auto &value = operands[immArgPos]; |
1246 | auto *constant = llvm::cast<llvm::Constant>(Val: value); |
1247 | auto attr = getScalarConstantAsAttr(builder, constant); |
1248 | assert(attr && attr.getType().isIntOrFloat() && |
1249 | "expected immarg to be float or integer constant" ); |
1250 | auto nameAttr = StringAttr::get(attr.getContext(), immArgName); |
1251 | attrsOut.push_back(Elt: {nameAttr, attr}); |
1252 | // Mark matched attribute values as null (so they can be removed below). |
1253 | value = nullptr; |
1254 | } |
1255 | |
1256 | for (llvm::Value *value : operands) { |
1257 | if (!value) |
1258 | continue; |
1259 | auto mlirValue = convertValue(value); |
1260 | if (failed(result: mlirValue)) |
1261 | return failure(); |
1262 | valuesOut.push_back(Elt: *mlirValue); |
1263 | } |
1264 | |
1265 | return success(); |
1266 | } |
1267 | |
1268 | IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) { |
1269 | IntegerAttr integerAttr; |
1270 | FailureOr<Value> converted = convertValue(value); |
1271 | bool success = succeeded(result: converted) && |
1272 | matchPattern(*converted, m_Constant(&integerAttr)); |
1273 | assert(success && "expected a constant integer value" ); |
1274 | (void)success; |
1275 | return integerAttr; |
1276 | } |
1277 | |
1278 | FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) { |
1279 | FloatAttr floatAttr; |
1280 | FailureOr<Value> converted = convertValue(value); |
1281 | bool success = |
1282 | succeeded(result: converted) && matchPattern(*converted, m_Constant(&floatAttr)); |
1283 | assert(success && "expected a constant float value" ); |
1284 | (void)success; |
1285 | return floatAttr; |
1286 | } |
1287 | |
1288 | DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) { |
1289 | auto *nodeAsVal = cast<llvm::MetadataAsValue>(Val: value); |
1290 | auto *node = cast<llvm::DILocalVariable>(Val: nodeAsVal->getMetadata()); |
1291 | return debugImporter->translate(node); |
1292 | } |
1293 | |
1294 | DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) { |
1295 | auto *nodeAsVal = cast<llvm::MetadataAsValue>(Val: value); |
1296 | auto *node = cast<llvm::DILabel>(Val: nodeAsVal->getMetadata()); |
1297 | return debugImporter->translate(node); |
1298 | } |
1299 | |
1300 | FPExceptionBehaviorAttr |
1301 | ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) { |
1302 | auto *metadata = cast<llvm::MetadataAsValue>(Val: value); |
1303 | auto *mdstr = cast<llvm::MDString>(Val: metadata->getMetadata()); |
1304 | std::optional<llvm::fp::ExceptionBehavior> optLLVM = |
1305 | llvm::convertStrToExceptionBehavior(mdstr->getString()); |
1306 | assert(optLLVM && "Expecting FP exception behavior" ); |
1307 | return builder.getAttr<FPExceptionBehaviorAttr>( |
1308 | convertFPExceptionBehaviorFromLLVM(*optLLVM)); |
1309 | } |
1310 | |
1311 | RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) { |
1312 | auto *metadata = cast<llvm::MetadataAsValue>(Val: value); |
1313 | auto *mdstr = cast<llvm::MDString>(Val: metadata->getMetadata()); |
1314 | std::optional<llvm::RoundingMode> optLLVM = |
1315 | llvm::convertStrToRoundingMode(mdstr->getString()); |
1316 | assert(optLLVM && "Expecting rounding mode" ); |
1317 | return builder.getAttr<RoundingModeAttr>( |
1318 | convertRoundingModeFromLLVM(*optLLVM)); |
1319 | } |
1320 | |
1321 | FailureOr<SmallVector<AliasScopeAttr>> |
1322 | ModuleImport::matchAliasScopeAttrs(llvm::Value *value) { |
1323 | auto *nodeAsVal = cast<llvm::MetadataAsValue>(Val: value); |
1324 | auto *node = cast<llvm::MDNode>(Val: nodeAsVal->getMetadata()); |
1325 | return lookupAliasScopeAttrs(node); |
1326 | } |
1327 | |
/// Translates the LLVM debug location `loc` into an MLIR location by
/// delegating to the debug importer.
Location ModuleImport::translateLoc(llvm::DILocation *loc) {
  return debugImporter->translateLoc(loc);
}
1331 | |
1332 | LogicalResult |
1333 | ModuleImport::convertBranchArgs(llvm::Instruction *branch, |
1334 | llvm::BasicBlock *target, |
1335 | SmallVectorImpl<Value> &blockArguments) { |
1336 | for (auto inst = target->begin(); isa<llvm::PHINode>(Val: inst); ++inst) { |
1337 | auto *phiInst = cast<llvm::PHINode>(Val: &*inst); |
1338 | llvm::Value *value = phiInst->getIncomingValueForBlock(BB: branch->getParent()); |
1339 | FailureOr<Value> converted = convertValue(value); |
1340 | if (failed(result: converted)) |
1341 | return failure(); |
1342 | blockArguments.push_back(Elt: *converted); |
1343 | } |
1344 | return success(); |
1345 | } |
1346 | |
1347 | LogicalResult |
1348 | ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst, |
1349 | SmallVectorImpl<Type> &types, |
1350 | SmallVectorImpl<Value> &operands) { |
1351 | if (!callInst->getType()->isVoidTy()) |
1352 | types.push_back(Elt: convertType(type: callInst->getType())); |
1353 | |
1354 | if (!callInst->getCalledFunction()) { |
1355 | FailureOr<Value> called = convertValue(value: callInst->getCalledOperand()); |
1356 | if (failed(result: called)) |
1357 | return failure(); |
1358 | operands.push_back(Elt: *called); |
1359 | } |
1360 | SmallVector<llvm::Value *> args(callInst->args()); |
1361 | FailureOr<SmallVector<Value>> arguments = convertValues(values: args); |
1362 | if (failed(result: arguments)) |
1363 | return failure(); |
1364 | llvm::append_range(C&: operands, R&: *arguments); |
1365 | return success(); |
1366 | } |
1367 | |
1368 | LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) { |
1369 | if (succeeded(result: iface.convertIntrinsic(builder, inst, moduleImport&: *this))) |
1370 | return success(); |
1371 | |
1372 | Location loc = translateLoc(loc: inst->getDebugLoc()); |
1373 | return emitError(loc) << "unhandled intrinsic: " << diag(value: *inst); |
1374 | } |
1375 | |
/// Converts `inst` into LLVM dialect operation(s) at the builder's current
/// insertion point.
///
/// Instructions whose conversion needs custom logic (br, switch, phi, call,
/// landingpad, invoke, getelementptr) are handled explicitly below. Every
/// other instruction is delegated to the generated `convertInstructionImpl`.
/// If neither path handles the instruction, an "unhandled instruction" error
/// is emitted and failure is returned.
LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
  // Convert all instructions that do not provide an MLIR builder.
  Location loc = translateLoc(loc: inst->getDebugLoc());
  if (inst->getOpcode() == llvm::Instruction::Br) {
    auto *brInst = cast<llvm::BranchInst>(Val: inst);

    // Collect the MLIR successor blocks, and the values forwarded to each of
    // them as block arguments, in successor order.
    SmallVector<Block *> succBlocks;
    SmallVector<SmallVector<Value>> succBlockArgs;
    for (auto i : llvm::seq<unsigned>(Begin: 0, End: brInst->getNumSuccessors())) {
      llvm::BasicBlock *succ = brInst->getSuccessor(i);
      SmallVector<Value> blockArgs;
      if (failed(result: convertBranchArgs(branch: brInst, target: succ, blockArguments&: blockArgs)))
        return failure();
      succBlocks.push_back(Elt: lookupBlock(block: succ));
      succBlockArgs.push_back(Elt: blockArgs);
    }

    // Unconditional branch: only the first successor is used.
    if (!brInst->isConditional()) {
      auto brOp = builder.create<LLVM::BrOp>(loc, succBlockArgs.front(),
                                             succBlocks.front());
      mapNoResultOp(inst, brOp);
      return success();
    }
    // Conditional branch: the first and the last successor become the two
    // destinations of the conditional branch op, in that order.
    FailureOr<Value> condition = convertValue(value: brInst->getCondition());
    if (failed(result: condition))
      return failure();
    auto condBrOp = builder.create<LLVM::CondBrOp>(
        loc, *condition, succBlocks.front(), succBlockArgs.front(),
        succBlocks.back(), succBlockArgs.back());
    mapNoResultOp(inst, condBrOp);
    return success();
  }
  if (inst->getOpcode() == llvm::Instruction::Switch) {
    auto *swInst = cast<llvm::SwitchInst>(Val: inst);
    // Process the condition value.
    FailureOr<Value> condition = convertValue(value: swInst->getCondition());
    if (failed(result: condition))
      return failure();
    SmallVector<Value> defaultBlockArgs;
    // Process the default case.
    llvm::BasicBlock *defaultBB = swInst->getDefaultDest();
    if (failed(result: convertBranchArgs(branch: swInst, target: defaultBB, blockArguments&: defaultBlockArgs)))
      return failure();

    // Process the cases. The per-case vectors are sized up front and filled by
    // case index; `caseOperandRefs` holds ValueRange views into
    // `caseOperands`, which therefore must stay alive until the op is built.
    unsigned numCases = swInst->getNumCases();
    SmallVector<SmallVector<Value>> caseOperands(numCases);
    SmallVector<ValueRange> caseOperandRefs(numCases);
    SmallVector<APInt> caseValues(numCases);
    SmallVector<Block *> caseBlocks(numCases);
    for (const auto &it : llvm::enumerate(First: swInst->cases())) {
      const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
      llvm::BasicBlock *succBB = caseHandle.getCaseSuccessor();
      if (failed(result: convertBranchArgs(branch: swInst, target: succBB, blockArguments&: caseOperands[it.index()])))
        return failure();
      caseOperandRefs[it.index()] = caseOperands[it.index()];
      caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
      caseBlocks[it.index()] = lookupBlock(block: succBB);
    }

    auto switchOp = builder.create<SwitchOp>(
        loc, *condition, lookupBlock(defaultBB), defaultBlockArgs, caseValues,
        caseBlocks, caseOperandRefs);
    mapNoResultOp(inst, switchOp);
    return success();
  }
  if (inst->getOpcode() == llvm::Instruction::PHI) {
    // A phi becomes an argument of the block currently being populated. Its
    // incoming values are presumably supplied by the predecessors'
    // terminators through convertBranchArgs (TODO confirm).
    Type type = convertType(type: inst->getType());
    mapValue(llvm: inst, mlir: builder.getInsertionBlock()->addArgument(
                             type, loc: translateLoc(loc: inst->getDebugLoc())));
    return success();
  }
  if (inst->getOpcode() == llvm::Instruction::Call) {
    auto *callInst = cast<llvm::CallInst>(Val: inst);

    SmallVector<Type> types;
    SmallVector<Value> operands;
    if (failed(result: convertCallTypeAndOperands(callInst, types, operands)))
      return failure();

    auto funcTy =
        dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
    if (!funcTy)
      return failure();

    CallOp callOp;

    // Direct calls reference the callee by symbol name. Indirect calls use
    // the operand-only builder; the callee pointer presumably travels in
    // `operands` (TODO confirm in convertCallTypeAndOperands).
    if (llvm::Function *callee = callInst->getCalledFunction()) {
      callOp = builder.create<CallOp>(
          loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
          operands);
    } else {
      callOp = builder.create<CallOp>(loc, funcTy, operands);
    }
    callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
    setFastmathFlagsAttr(inst, op: callOp);
    // Void calls have no result to map; record the op itself instead.
    if (!callInst->getType()->isVoidTy())
      mapValue(inst, callOp.getResult());
    else
      mapNoResultOp(inst, callOp);
    return success();
  }
  if (inst->getOpcode() == llvm::Instruction::LandingPad) {
    auto *lpInst = cast<llvm::LandingPadInst>(Val: inst);

    // Each clause value becomes an operand of the landingpad op.
    SmallVector<Value> operands;
    operands.reserve(N: lpInst->getNumClauses());
    for (auto i : llvm::seq<unsigned>(Begin: 0, End: lpInst->getNumClauses())) {
      FailureOr<Value> operand = convertValue(value: lpInst->getClause(Idx: i));
      if (failed(result: operand))
        return failure();
      operands.push_back(Elt: *operand);
    }

    Type type = convertType(type: lpInst->getType());
    auto lpOp =
        builder.create<LandingpadOp>(loc, type, lpInst->isCleanup(), operands);
    mapValue(inst, lpOp);
    return success();
  }
  if (inst->getOpcode() == llvm::Instruction::Invoke) {
    auto *invokeInst = cast<llvm::InvokeInst>(Val: inst);

    SmallVector<Type> types;
    SmallVector<Value> operands;
    if (failed(result: convertCallTypeAndOperands(callInst: invokeInst, types, operands)))
      return failure();

    // Check whether the invoke result is an argument to the normal destination
    // block.
    bool invokeResultUsedInPhi = llvm::any_of(
        Range: invokeInst->getNormalDest()->phis(), P: [&](const llvm::PHINode &phi) {
          return phi.getIncomingValueForBlock(BB: invokeInst->getParent()) ==
                 invokeInst;
        });

    Block *normalDest = lookupBlock(block: invokeInst->getNormalDest());
    Block *directNormalDest = normalDest;
    if (invokeResultUsedInPhi) {
      // The invoke result cannot be an argument to the normal destination
      // block, as that would imply using the invoke operation result in its
      // definition, so we need to create a dummy block to serve as an
      // intermediate destination.
      OpBuilder::InsertionGuard g(builder);
      directNormalDest = builder.createBlock(insertBefore: normalDest);
    }

    SmallVector<Value> unwindArgs;
    if (failed(result: convertBranchArgs(branch: invokeInst, target: invokeInst->getUnwindDest(),
                                 blockArguments&: unwindArgs)))
      return failure();

    auto funcTy =
        dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
    if (!funcTy)
      return failure();

    // Create the invoke operation. Normal destination block arguments will be
    // added later on to handle the case in which the operation result is
    // included in this list.
    InvokeOp invokeOp;
    if (llvm::Function *callee = invokeInst->getCalledFunction()) {
      invokeOp = builder.create<InvokeOp>(
          loc, funcTy,
          SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
          directNormalDest, ValueRange(),
          lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
    } else {
      invokeOp = builder.create<InvokeOp>(
          loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
          ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
    }
    invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
    // The result must be mapped before converting the normal destination's
    // branch arguments below, since those may reference it.
    if (!invokeInst->getType()->isVoidTy())
      mapValue(inst, invokeOp.getResults().front());
    else
      mapNoResultOp(inst, invokeOp);

    SmallVector<Value> normalArgs;
    if (failed(result: convertBranchArgs(branch: invokeInst, target: invokeInst->getNormalDest(),
                                 blockArguments&: normalArgs)))
      return failure();

    if (invokeResultUsedInPhi) {
      // The dummy normal dest block will just host an unconditional branch
      // instruction to the normal destination block passing the required block
      // arguments (including the invoke operation's result).
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPointToStart(directNormalDest);
      builder.create<LLVM::BrOp>(loc, normalArgs, normalDest);
    } else {
      // If the invoke operation's result is not a block argument to the normal
      // destination block, just add the block arguments as usual.
      assert(llvm::none_of(
                 normalArgs,
                 [&](Value val) { return val.getDefiningOp() == invokeOp; }) &&
             "An llvm.invoke operation cannot pass its result as a block "
             "argument." );
      invokeOp.getNormalDestOperandsMutable().append(normalArgs);
    }

    return success();
  }
  if (inst->getOpcode() == llvm::Instruction::GetElementPtr) {
    auto *gepInst = cast<llvm::GetElementPtrInst>(Val: inst);
    Type sourceElementType = convertType(type: gepInst->getSourceElementType());
    FailureOr<Value> basePtr = convertValue(value: gepInst->getOperand(i_nocapture: 0));
    if (failed(result: basePtr))
      return failure();

    // Treat every index as dynamic since GEPOp::build will refine constant
    // indices into static attributes later. One small downside of this
    // approach is that many unused `llvm.mlir.constant` operations may be
    // emitted initially.
    SmallVector<GEPArg> indices;
    for (llvm::Value *operand : llvm::drop_begin(RangeOrContainer: gepInst->operand_values())) {
      FailureOr<Value> index = convertValue(value: operand);
      if (failed(result: index))
        return failure();
      indices.push_back(Elt: *index);
    }

    Type type = convertType(type: inst->getType());
    auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
                                       indices, gepInst->isInBounds());
    mapValue(inst, gepOp);
    return success();
  }

  // Convert all instructions that have an mlirBuilder.
  if (succeeded(result: convertInstructionImpl(odsBuilder&: builder, inst, moduleImport&: *this, iface)))
    return success();

  return emitError(loc) << "unhandled instruction: " << diag(value: *inst);
}
1611 | |
1612 | LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) { |
1613 | // FIXME: Support uses of SubtargetData. |
1614 | // FIXME: Add support for call / operand attributes. |
1615 | // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch, |
1616 | // callbr, vaarg, catchpad, cleanuppad instructions. |
1617 | |
1618 | // Convert LLVM intrinsics calls to MLIR intrinsics. |
1619 | if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(Val: inst)) |
1620 | return convertIntrinsic(inst: intrinsic); |
1621 | |
1622 | // Convert all remaining LLVM instructions to MLIR operations. |
1623 | return convertInstruction(inst); |
1624 | } |
1625 | |
1626 | FlatSymbolRefAttr ModuleImport::getPersonalityAsAttr(llvm::Function *f) { |
1627 | if (!f->hasPersonalityFn()) |
1628 | return nullptr; |
1629 | |
1630 | llvm::Constant *pf = f->getPersonalityFn(); |
1631 | |
1632 | // If it directly has a name, we can use it. |
1633 | if (pf->hasName()) |
1634 | return SymbolRefAttr::get(builder.getContext(), pf->getName()); |
1635 | |
1636 | // If it doesn't have a name, currently, only function pointers that are |
1637 | // bitcast to i8* are parsed. |
1638 | if (auto *ce = dyn_cast<llvm::ConstantExpr>(Val: pf)) { |
1639 | if (ce->getOpcode() == llvm::Instruction::BitCast && |
1640 | ce->getType() == llvm::PointerType::getUnqual(C&: f->getContext())) { |
1641 | if (auto *func = dyn_cast<llvm::Function>(ce->getOperand(0))) |
1642 | return SymbolRefAttr::get(builder.getContext(), func->getName()); |
1643 | } |
1644 | } |
1645 | return FlatSymbolRefAttr(); |
1646 | } |
1647 | |
1648 | static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) { |
1649 | llvm::MemoryEffects memEffects = func->getMemoryEffects(); |
1650 | |
1651 | auto othermem = convertModRefInfoFromLLVM( |
1652 | memEffects.getModRef(Loc: llvm::MemoryEffects::Location::Other)); |
1653 | auto argMem = convertModRefInfoFromLLVM( |
1654 | memEffects.getModRef(Loc: llvm::MemoryEffects::Location::ArgMem)); |
1655 | auto inaccessibleMem = convertModRefInfoFromLLVM( |
1656 | memEffects.getModRef(Loc: llvm::MemoryEffects::Location::InaccessibleMem)); |
1657 | auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem, |
1658 | inaccessibleMem); |
1659 | // Only set the attr when it does not match the default value. |
1660 | if (memAttr.isReadWrite()) |
1661 | return; |
1662 | funcOp.setMemoryAttr(memAttr); |
1663 | } |
1664 | |
1665 | // List of LLVM IR attributes that map to an explicit attribute on the MLIR |
1666 | // LLVMFuncOp. |
1667 | static constexpr std::array ExplicitAttributes{ |
1668 | StringLiteral("aarch64_pstate_sm_enabled" ), |
1669 | StringLiteral("aarch64_pstate_sm_body" ), |
1670 | StringLiteral("aarch64_pstate_sm_compatible" ), |
1671 | StringLiteral("aarch64_new_za" ), |
1672 | StringLiteral("aarch64_preserves_za" ), |
1673 | StringLiteral("aarch64_in_za" ), |
1674 | StringLiteral("aarch64_out_za" ), |
1675 | StringLiteral("aarch64_inout_za" ), |
1676 | StringLiteral("vscale_range" ), |
1677 | StringLiteral("frame-pointer" ), |
1678 | StringLiteral("target-features" ), |
1679 | StringLiteral("unsafe-fp-math" ), |
1680 | StringLiteral("no-infs-fp-math" ), |
1681 | StringLiteral("no-nans-fp-math" ), |
1682 | StringLiteral("approx-func-fp-math" ), |
1683 | StringLiteral("no-signed-zeros-fp-math" ), |
1684 | }; |
1685 | |
1686 | static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) { |
1687 | MLIRContext *context = funcOp.getContext(); |
1688 | SmallVector<Attribute> passthroughs; |
1689 | llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes( |
1690 | Index: llvm::AttributeList::AttrIndex::FunctionIndex); |
1691 | for (llvm::Attribute attr : funcAttrs) { |
1692 | // Skip the memory attribute since the LLVMFuncOp has an explicit memory |
1693 | // attribute. |
1694 | if (attr.hasAttribute(llvm::Attribute::Memory)) |
1695 | continue; |
1696 | |
1697 | // Skip invalid type attributes. |
1698 | if (attr.isTypeAttribute()) { |
1699 | emitWarning(funcOp.getLoc(), |
1700 | "type attributes on a function are invalid, skipping it" ); |
1701 | continue; |
1702 | } |
1703 | |
1704 | StringRef attrName; |
1705 | if (attr.isStringAttribute()) |
1706 | attrName = attr.getKindAsString(); |
1707 | else |
1708 | attrName = llvm::Attribute::getNameFromAttrKind(AttrKind: attr.getKindAsEnum()); |
1709 | auto keyAttr = StringAttr::get(context, attrName); |
1710 | |
1711 | // Skip attributes that map to an explicit attribute on the LLVMFuncOp. |
1712 | if (llvm::is_contained(Range: ExplicitAttributes, Element: attrName)) |
1713 | continue; |
1714 | |
1715 | if (attr.isStringAttribute()) { |
1716 | StringRef val = attr.getValueAsString(); |
1717 | if (val.empty()) { |
1718 | passthroughs.push_back(Elt: keyAttr); |
1719 | continue; |
1720 | } |
1721 | passthroughs.push_back( |
1722 | ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); |
1723 | continue; |
1724 | } |
1725 | if (attr.isIntAttribute()) { |
1726 | auto val = std::to_string(val: attr.getValueAsInt()); |
1727 | passthroughs.push_back( |
1728 | ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); |
1729 | continue; |
1730 | } |
1731 | if (attr.isEnumAttribute()) { |
1732 | passthroughs.push_back(Elt: keyAttr); |
1733 | continue; |
1734 | } |
1735 | |
1736 | llvm_unreachable("unexpected attribute kind" ); |
1737 | } |
1738 | |
1739 | if (!passthroughs.empty()) |
1740 | funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs)); |
1741 | } |
1742 | |
1743 | void ModuleImport::processFunctionAttributes(llvm::Function *func, |
1744 | LLVMFuncOp funcOp) { |
1745 | processMemoryEffects(func, funcOp); |
1746 | processPassthroughAttrs(func, funcOp); |
1747 | |
1748 | if (func->hasFnAttribute(Kind: "aarch64_pstate_sm_enabled" )) |
1749 | funcOp.setArmStreaming(true); |
1750 | else if (func->hasFnAttribute(Kind: "aarch64_pstate_sm_body" )) |
1751 | funcOp.setArmLocallyStreaming(true); |
1752 | else if (func->hasFnAttribute(Kind: "aarch64_pstate_sm_compatible" )) |
1753 | funcOp.setArmStreamingCompatible(true); |
1754 | |
1755 | if (func->hasFnAttribute(Kind: "aarch64_new_za" )) |
1756 | funcOp.setArmNewZa(true); |
1757 | else if (func->hasFnAttribute(Kind: "aarch64_in_za" )) |
1758 | funcOp.setArmInZa(true); |
1759 | else if (func->hasFnAttribute(Kind: "aarch64_out_za" )) |
1760 | funcOp.setArmOutZa(true); |
1761 | else if (func->hasFnAttribute(Kind: "aarch64_inout_za" )) |
1762 | funcOp.setArmInoutZa(true); |
1763 | else if (func->hasFnAttribute(Kind: "aarch64_preserves_za" )) |
1764 | funcOp.setArmPreservesZa(true); |
1765 | |
1766 | llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange); |
1767 | if (attr.isValid()) { |
1768 | MLIRContext *context = funcOp.getContext(); |
1769 | auto intTy = IntegerType::get(context, 32); |
1770 | funcOp.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get( |
1771 | context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()), |
1772 | IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0)))); |
1773 | } |
1774 | |
1775 | // Process frame-pointer attribute. |
1776 | if (func->hasFnAttribute(Kind: "frame-pointer" )) { |
1777 | StringRef stringRefFramePointerKind = |
1778 | func->getFnAttribute(Kind: "frame-pointer" ).getValueAsString(); |
1779 | funcOp.setFramePointerAttr(LLVM::FramePointerKindAttr::get( |
1780 | funcOp.getContext(), LLVM::framePointerKind::symbolizeFramePointerKind( |
1781 | stringRefFramePointerKind) |
1782 | .value())); |
1783 | } |
1784 | |
1785 | if (llvm::Attribute attr = func->getFnAttribute("target-cpu" ); |
1786 | attr.isStringAttribute()) |
1787 | funcOp.setTargetCpuAttr(StringAttr::get(context, attr.getValueAsString())); |
1788 | |
1789 | if (llvm::Attribute attr = func->getFnAttribute("target-features" ); |
1790 | attr.isStringAttribute()) |
1791 | funcOp.setTargetFeaturesAttr( |
1792 | LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString())); |
1793 | |
1794 | if (llvm::Attribute attr = func->getFnAttribute(Kind: "unsafe-fp-math" ); |
1795 | attr.isStringAttribute()) |
1796 | funcOp.setUnsafeFpMath(attr.getValueAsBool()); |
1797 | |
1798 | if (llvm::Attribute attr = func->getFnAttribute(Kind: "no-infs-fp-math" ); |
1799 | attr.isStringAttribute()) |
1800 | funcOp.setNoInfsFpMath(attr.getValueAsBool()); |
1801 | |
1802 | if (llvm::Attribute attr = func->getFnAttribute(Kind: "no-nans-fp-math" ); |
1803 | attr.isStringAttribute()) |
1804 | funcOp.setNoNansFpMath(attr.getValueAsBool()); |
1805 | |
1806 | if (llvm::Attribute attr = func->getFnAttribute(Kind: "approx-func-fp-math" ); |
1807 | attr.isStringAttribute()) |
1808 | funcOp.setApproxFuncFpMath(attr.getValueAsBool()); |
1809 | |
1810 | if (llvm::Attribute attr = func->getFnAttribute(Kind: "no-signed-zeros-fp-math" ); |
1811 | attr.isStringAttribute()) |
1812 | funcOp.setNoSignedZerosFpMath(attr.getValueAsBool()); |
1813 | } |
1814 | |
1815 | DictionaryAttr |
1816 | ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, |
1817 | OpBuilder &builder) { |
1818 | SmallVector<NamedAttribute> paramAttrs; |
1819 | for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { |
1820 | auto llvmAttr = llvmParamAttrs.getAttribute(Kind: llvmKind); |
1821 | // Skip attributes that are not attached. |
1822 | if (!llvmAttr.isValid()) |
1823 | continue; |
1824 | Attribute mlirAttr; |
1825 | if (llvmAttr.isTypeAttribute()) |
1826 | mlirAttr = TypeAttr::get(convertType(llvmAttr.getValueAsType())); |
1827 | else if (llvmAttr.isIntAttribute()) |
1828 | mlirAttr = builder.getI64IntegerAttr(llvmAttr.getValueAsInt()); |
1829 | else if (llvmAttr.isEnumAttribute()) |
1830 | mlirAttr = builder.getUnitAttr(); |
1831 | else |
1832 | llvm_unreachable("unexpected parameter attribute kind" ); |
1833 | paramAttrs.push_back(Elt: builder.getNamedAttr(name: mlirName, val: mlirAttr)); |
1834 | } |
1835 | |
1836 | return builder.getDictionaryAttr(paramAttrs); |
1837 | } |
1838 | |
1839 | void ModuleImport::convertParameterAttributes(llvm::Function *func, |
1840 | LLVMFuncOp funcOp, |
1841 | OpBuilder &builder) { |
1842 | auto llvmAttrs = func->getAttributes(); |
1843 | for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) { |
1844 | llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(ArgNo: i); |
1845 | funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder)); |
1846 | } |
1847 | // Convert the result attributes and attach them wrapped in an ArrayAttribute |
1848 | // to the funcOp. |
1849 | llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); |
1850 | if (!llvmResAttr.hasAttributes()) |
1851 | return; |
1852 | funcOp.setResAttrsAttr( |
1853 | builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); |
1854 | } |
1855 | |
/// Imports `func` as an LLVMFuncOp appended to the end of the MLIR module.
///
/// Intrinsic declarations that the dialect interface reports as convertible
/// are not imported as functions (calls to them are presumably handled by
/// convertIntrinsic). Declarations only get their signature, attributes and
/// metadata. Definitions additionally get their reachable basic blocks
/// converted in topological order, followed by the delayed conversion of
/// debug intrinsics.
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
  clearRegionState();

  // NOTE(review): the dyn_cast result is used below without a null check;
  // presumably convertType always maps an llvm::FunctionType to an
  // LLVMFunctionType — TODO confirm.
  auto functionType =
      dyn_cast<LLVMFunctionType>(convertType(func->getFunctionType()));
  if (func->isIntrinsic() &&
      iface.isConvertibleIntrinsic(id: func->getIntrinsicID()))
    return success();

  bool dsoLocal = func->hasLocalLinkage();
  CConv cconv = convertCConvFromLLVM(func->getCallingConv());

  // Insert the function at the end of the module.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(mlirModule.getBody(), mlirModule.getBody()->end());

  Location loc = debugImporter->translateFuncLocation(func);
  LLVMFuncOp funcOp = builder.create<LLVMFuncOp>(
      loc, func->getName(), functionType,
      convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);

  convertParameterAttributes(func, funcOp, builder);

  // A personality that cannot be mapped to a symbol is dropped with a warning.
  if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
    funcOp.setPersonalityAttr(personality);
  else if (func->hasPersonalityFn())
    emitWarning(funcOp.getLoc(), "could not deduce personality, skipping it" );

  if (func->hasGC())
    funcOp.setGarbageCollector(StringRef(func->getGC()));

  if (func->hasAtLeastLocalUnnamedAddr())
    funcOp.setUnnamedAddr(convertUnnamedAddrFromLLVM(func->getUnnamedAddr()));

  if (func->hasSection())
    funcOp.setSection(StringRef(func->getSection()));

  funcOp.setVisibility_(convertVisibilityFromLLVM(func->getVisibility()));

  if (func->hasComdat())
    funcOp.setComdatAttr(comdatMapping.lookup(func->getComdat()));

  if (llvm::MaybeAlign maybeAlign = func->getAlign())
    funcOp.setAlignment(maybeAlign->value());

  // Handle Function attributes.
  processFunctionAttributes(func, funcOp);

  // Convert non-debug metadata by using the dialect interface. Metadata the
  // interface cannot handle only produces a warning, not a failure.
  SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
  func->getAllMetadata(MDs&: allMetadata);
  for (auto &[kind, node] : allMetadata) {
    if (!iface.isConvertibleMetadata(kind))
      continue;
    if (failed(iface.setMetadataAttrs(builder, kind, node, op: funcOp, moduleImport&: *this))) {
      emitWarning(funcOp.getLoc())
          << "unhandled function metadata: " << diagMD(node, module: llvmModule.get())
          << " on " << diag(value: *func);
    }
  }

  // Declarations have no body to convert.
  if (func->isDeclaration())
    return success();

  // Collect the set of basic blocks reachable from the function's entry block.
  // This step is crucial as LLVM IR can contain unreachable blocks that
  // self-dominate. As a result, an operation might utilize a variable it
  // defines, which the import does not support. Given that MLIR lacks block
  // label support, we can safely remove unreachable blocks, as there are no
  // indirect branch instructions that could potentially target these blocks.
  // The loop body is empty: the traversal itself populates `reachable`.
  llvm::df_iterator_default_set<llvm::BasicBlock *> reachable;
  for (llvm::BasicBlock *basicBlock : llvm::depth_first_ext(G: func, S&: reachable))
    (void)basicBlock;

  // Eagerly create all reachable blocks, preserving the original block order,
  // so branches can look up their successors before those are converted.
  SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
  for (llvm::BasicBlock &basicBlock : *func) {
    // Skip unreachable blocks.
    if (!reachable.contains(Ptr: &basicBlock))
      continue;
    Region &body = funcOp.getBody();
    Block *block = builder.createBlock(parent: &body, insertPt: body.end());
    mapBlock(llvm: &basicBlock, mlir: block);
    reachableBasicBlocks.push_back(&basicBlock);
  }

  // Add function arguments to the entry block.
  for (const auto &it : llvm::enumerate(First: func->args())) {
    BlockArgument blockArg = funcOp.getFunctionBody().addArgument(
        functionType.getParamType(it.index()), funcOp.getLoc());
    mapValue(llvm: &it.value(), mlir: blockArg);
  }

  // Process the blocks in topological order. The ordered traversal ensures
  // operands defined in a dominating block have a valid mapping to an MLIR
  // value once a block is translated.
  SetVector<llvm::BasicBlock *> blocks =
      getTopologicallySortedBlocks(reachableBasicBlocks);
  setConstantInsertionPointToStart(lookupBlock(blocks.front()));
  for (llvm::BasicBlock *basicBlock : blocks)
    if (failed(processBasicBlock(basicBlock, lookupBlock(basicBlock))))
      return failure();

  // Process the debug intrinsics that require a delayed conversion after
  // everything else was converted.
  if (failed(result: processDebugIntrinsics()))
    return failure();

  return success();
}
1966 | |
1967 | /// Checks if `dbgIntr` is a kill location that holds metadata instead of an SSA |
1968 | /// value. |
1969 | static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) { |
1970 | if (!dbgIntr->isKillLocation()) |
1971 | return false; |
1972 | llvm::Value *value = dbgIntr->getArgOperand(i: 0); |
1973 | auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(Val: value); |
1974 | if (!nodeAsVal) |
1975 | return false; |
1976 | return !isa<llvm::ValueAsMetadata>(Val: nodeAsVal->getMetadata()); |
1977 | } |
1978 | |
/// Converts a single debug variable intrinsic (dbg.declare or dbg.value) into
/// the matching LLVM dialect operation.
///
/// Unsupported forms are dropped (with a warning only when expensive warnings
/// are enabled) and still report success. These forms are arg lists,
/// metadata kill locations, untranslatable variables, and operands defined
/// by a terminator that dominates no block. Only a failure to convert the
/// location operand is reported as an error.
LogicalResult
ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
                                    DominanceInfo &domInfo) {
  Location loc = translateLoc(loc: dbgIntr->getDebugLoc());
  // Drops the intrinsic, optionally warning about it; never fails.
  auto emitUnsupportedWarning = [&]() {
    if (emitExpensiveWarnings)
      emitWarning(loc) << "dropped intrinsic: " << diag(value: *dbgIntr);
    return success();
  };
  // Drop debug intrinsics with arg lists.
  // TODO: Support debug intrinsics that have arg lists.
  if (dbgIntr->hasArgList())
    return emitUnsupportedWarning();
  // Kill locations can have metadata nodes as location operand. This
  // cannot be converted to poison as the type cannot be reconstructed.
  // TODO: find a way to support this case.
  if (isMetadataKillLocation(dbgIntr))
    return emitUnsupportedWarning();
  // Drop debug intrinsics if the associated variable information cannot be
  // translated due to cyclic debug metadata.
  // TODO: Support cyclic debug metadata.
  DILocalVariableAttr localVariableAttr =
      matchLocalVariableAttr(dbgIntr->getArgOperand(1));
  if (!localVariableAttr)
    return emitUnsupportedWarning();
  FailureOr<Value> argOperand = convertMetadataValue(value: dbgIntr->getArgOperand(i: 0));
  if (failed(result: argOperand))
    return emitError(loc) << "failed to convert a debug intrinsic operand: "
                          << diag(value: *dbgIntr);

  // Ensure that the debug intrinsic is inserted right after its operand is
  // defined. Otherwise, the operand might not necessarily dominate the
  // intrinsic. If the defining operation is a terminator, insert the intrinsic
  // into a dominated block.
  OpBuilder::InsertionGuard guard(builder);
  if (Operation *op = argOperand->getDefiningOp();
      op && op->hasTrait<OpTrait::IsTerminator>()) {
    // Find a dominated block that can hold the debug intrinsic.
    auto dominatedBlocks = domInfo.getNode(a: op->getBlock())->children();
    // If no block is dominated by the terminator, this intrinsic cannot be
    // converted.
    if (dominatedBlocks.empty())
      return emitUnsupportedWarning();
    // Set insertion point before the terminator, to avoid inserting something
    // before landingpads.
    Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock();
    builder.setInsertionPoint(dominatedBlock->getTerminator());
  } else {
    builder.setInsertionPointAfterValue(*argOperand);
  }
  auto locationExprAttr =
      debugImporter->translateExpression(dbgIntr->getExpression());
  // Build the dialect op matching the intrinsic kind.
  Operation *op =
      llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr)
          .Case(caseFn: [&](llvm::DbgDeclareInst *) {
            return builder.create<LLVM::DbgDeclareOp>(
                loc, *argOperand, localVariableAttr, locationExprAttr);
          })
          .Case(caseFn: [&](llvm::DbgValueInst *) {
            return builder.create<LLVM::DbgValueOp>(
                loc, *argOperand, localVariableAttr, locationExprAttr);
          });
  mapNoResultOp(llvm: dbgIntr, mlir: op);
  setNonDebugMetadataAttrs(inst: dbgIntr, op);
  return success();
}
2045 | |
2046 | LogicalResult ModuleImport::processDebugIntrinsics() { |
2047 | DominanceInfo domInfo; |
2048 | for (llvm::Instruction *inst : debugIntrinsics) { |
2049 | auto *intrCall = cast<llvm::DbgVariableIntrinsic>(Val: inst); |
2050 | if (failed(result: processDebugIntrinsic(dbgIntr: intrCall, domInfo))) |
2051 | return failure(); |
2052 | } |
2053 | return success(); |
2054 | } |
2055 | |
2056 | LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb, |
2057 | Block *block) { |
2058 | builder.setInsertionPointToStart(block); |
2059 | for (llvm::Instruction &inst : *bb) { |
2060 | if (failed(result: processInstruction(inst: &inst))) |
2061 | return failure(); |
2062 | |
2063 | // Skip additional processing when the instructions is a debug intrinsics |
2064 | // that was not yet converted. |
2065 | if (debugIntrinsics.contains(key: &inst)) |
2066 | continue; |
2067 | |
2068 | // Set the non-debug metadata attributes on the imported operation and emit |
2069 | // a warning if an instruction other than a phi instruction is dropped |
2070 | // during the import. |
2071 | if (Operation *op = lookupOperation(inst: &inst)) { |
2072 | setNonDebugMetadataAttrs(inst: &inst, op); |
2073 | } else if (inst.getOpcode() != llvm::Instruction::PHI) { |
2074 | if (emitExpensiveWarnings) { |
2075 | Location loc = debugImporter->translateLoc(loc: inst.getDebugLoc()); |
2076 | emitWarning(loc) << "dropped instruction: " << diag(value: inst); |
2077 | } |
2078 | } |
2079 | } |
2080 | return success(); |
2081 | } |
2082 | |
2083 | FailureOr<SmallVector<AccessGroupAttr>> |
2084 | ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const { |
2085 | return loopAnnotationImporter->lookupAccessGroupAttrs(node); |
2086 | } |
2087 | |
2088 | LoopAnnotationAttr |
2089 | ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node, |
2090 | Location loc) const { |
2091 | return loopAnnotationImporter->translateLoopAnnotation(node, loc); |
2092 | } |
2093 | |
2094 | OwningOpRef<ModuleOp> |
2095 | mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule, |
2096 | MLIRContext *context, bool emitExpensiveWarnings, |
2097 | bool dropDICompositeTypeElements) { |
2098 | // Preload all registered dialects to allow the import to iterate the |
2099 | // registered LLVMImportDialectInterface implementations and query the |
2100 | // supported LLVM IR constructs before starting the translation. Assumes the |
2101 | // LLVM and DLTI dialects that convert the core LLVM IR constructs have been |
2102 | // registered before. |
2103 | assert(llvm::is_contained(context->getAvailableDialects(), |
2104 | LLVMDialect::getDialectNamespace())); |
2105 | assert(llvm::is_contained(context->getAvailableDialects(), |
2106 | DLTIDialect::getDialectNamespace())); |
2107 | context->loadAllAvailableDialects(); |
2108 | OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get( |
2109 | StringAttr::get(context, llvmModule->getSourceFileName()), /*line=*/0, |
2110 | /*column=*/0))); |
2111 | |
2112 | ModuleImport moduleImport(module.get(), std::move(llvmModule), |
2113 | emitExpensiveWarnings, dropDICompositeTypeElements); |
2114 | if (failed(result: moduleImport.initializeImportInterface())) |
2115 | return {}; |
2116 | if (failed(result: moduleImport.convertDataLayout())) |
2117 | return {}; |
2118 | if (failed(result: moduleImport.convertComdats())) |
2119 | return {}; |
2120 | if (failed(result: moduleImport.convertMetadata())) |
2121 | return {}; |
2122 | if (failed(result: moduleImport.convertGlobals())) |
2123 | return {}; |
2124 | if (failed(result: moduleImport.convertFunctions())) |
2125 | return {}; |
2126 | |
2127 | return module; |
2128 | } |
2129 | |