1 | //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to convert MLIR Func and builtin dialects |
10 | // into the LLVM IR dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
15 | |
16 | #include "mlir/Analysis/DataLayoutAnalysis.h" |
17 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
18 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
19 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
20 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
21 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
22 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
23 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
24 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
25 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
26 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
27 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
28 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
29 | #include "mlir/IR/Attributes.h" |
30 | #include "mlir/IR/Builders.h" |
31 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
32 | #include "mlir/IR/BuiltinAttributes.h" |
33 | #include "mlir/IR/BuiltinOps.h" |
34 | #include "mlir/IR/IRMapping.h" |
35 | #include "mlir/IR/PatternMatch.h" |
36 | #include "mlir/IR/SymbolTable.h" |
37 | #include "mlir/IR/TypeUtilities.h" |
38 | #include "mlir/Transforms/DialectConversion.h" |
39 | #include "mlir/Transforms/Passes.h" |
40 | #include "llvm/ADT/SmallVector.h" |
41 | #include "llvm/ADT/TypeSwitch.h" |
42 | #include "llvm/IR/DerivedTypes.h" |
43 | #include "llvm/IR/IRBuilder.h" |
44 | #include "llvm/IR/Type.h" |
45 | #include "llvm/Support/Casting.h" |
46 | #include "llvm/Support/CommandLine.h" |
47 | #include "llvm/Support/FormatVariadic.h" |
48 | #include <algorithm> |
49 | #include <functional> |
50 | #include <optional> |
51 | |
52 | namespace mlir { |
53 | #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS |
54 | #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS |
55 | #include "mlir/Conversion/Passes.h.inc" |
56 | } // namespace mlir |
57 | |
58 | using namespace mlir; |
59 | |
60 | #define PASS_NAME "convert-func-to-llvm" |
61 | |
// Names of the function attributes this conversion inspects.
// - `func.varargs`: read as a `BoolAttr` during function conversion to decide
//   whether the converted signature is variadic.
// - `llvm.linkage`: read as an `LLVM::LinkageAttr` to choose the linkage of the
//   converted function (external linkage is used when it is absent).
// - `llvm.bareptr`: its presence on a function selects the bare-pointer
//   calling convention for that function and for direct calls to it.
62 | static constexpr StringRef varargsAttrName = "func.varargs"; |
63 | static constexpr StringRef linkageAttrName = "llvm.linkage"; |
64 | static constexpr StringRef barePtrAttrName = "llvm.bareptr"; |
65 | |
66 | /// Return `true` if the `op` should use bare pointer calling convention. |
67 | static bool shouldUseBarePtrCallConv(Operation *op, |
68 | const LLVMTypeConverter *typeConverter) { |
69 | return (op && op->hasAttr(name: barePtrAttrName)) || |
70 | typeConverter->getOptions().useBarePtrCallConv; |
71 | } |
72 | |
73 | /// Only retain those attributes that are not constructed by |
74 | /// `LLVMFuncOp::build`. |
75 | static void filterFuncAttributes(FunctionOpInterface func, |
76 | SmallVectorImpl<NamedAttribute> &result) { |
77 | for (const NamedAttribute &attr : func->getDiscardableAttrs()) { |
78 | if (attr.getName() == linkageAttrName || |
79 | attr.getName() == varargsAttrName || |
80 | attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName()) |
81 | continue; |
82 | result.push_back(attr); |
83 | } |
84 | } |
85 | |
86 | /// Propagate argument/results attributes. |
87 | static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, |
88 | FunctionOpInterface funcOp, |
89 | LLVM::LLVMFuncOp wrapperFuncOp) { |
90 | auto argAttrs = funcOp.getAllArgAttrs(); |
91 | if (!resultStructType) { |
92 | if (auto resAttrs = funcOp.getAllResultAttrs()) |
93 | wrapperFuncOp.setAllResultAttrs(resAttrs); |
94 | if (argAttrs) |
95 | wrapperFuncOp.setAllArgAttrs(argAttrs); |
96 | } else { |
97 | SmallVector<Attribute> argAttributes; |
98 | // Only modify the argument and result attributes when the result is now |
99 | // an argument. |
100 | if (argAttrs) { |
101 | argAttributes.push_back(builder.getDictionaryAttr({})); |
102 | argAttributes.append(argAttrs.begin(), argAttrs.end()); |
103 | wrapperFuncOp.setAllArgAttrs(argAttributes); |
104 | } |
105 | } |
106 | cast<FunctionOpInterface>(wrapperFuncOp.getOperation()) |
107 | .setVisibility(funcOp.getVisibility()); |
108 | } |
109 | |
110 | /// Creates an auxiliary function with pointer-to-memref-descriptor-struct |
111 | /// arguments instead of unpacked arguments. This function can be called from C |
112 | /// by passing a pointer to a C struct corresponding to a memref descriptor. |
113 | /// Similarly, returned memrefs are passed via pointers to a C struct that is |
114 | /// passed as additional argument. |
115 | /// Internally, the auxiliary function unpacks the descriptor into individual |
116 | /// components and forwards them to `newFuncOp` and forwards the results to |
117 | /// the extra arguments. |
118 | static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, |
119 | const LLVMTypeConverter &typeConverter, |
120 | FunctionOpInterface funcOp, |
121 | LLVM::LLVMFuncOp newFuncOp) { |
122 | auto type = cast<FunctionType>(funcOp.getFunctionType()); |
123 | auto [wrapperFuncType, resultStructType] = |
124 | typeConverter.convertFunctionTypeCWrapper(type: type); |
125 | |
126 | SmallVector<NamedAttribute> attributes; |
127 | filterFuncAttributes(funcOp, attributes); |
128 | |
129 | auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>( |
130 | loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), |
131 | wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false, |
132 | /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); |
133 | propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp); |
134 | |
135 | OpBuilder::InsertionGuard guard(rewriter); |
136 | rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter)); |
137 | |
138 | SmallVector<Value, 8> args; |
139 | size_t argOffset = resultStructType ? 1 : 0; |
140 | for (auto [index, argType] : llvm::enumerate(type.getInputs())) { |
141 | Value arg = wrapperFuncOp.getArgument(index + argOffset); |
142 | if (auto memrefType = dyn_cast<MemRefType>(argType)) { |
143 | Value loaded = rewriter.create<LLVM::LoadOp>( |
144 | loc, typeConverter.convertType(memrefType), arg); |
145 | MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); |
146 | continue; |
147 | } |
148 | if (isa<UnrankedMemRefType>(argType)) { |
149 | Value loaded = rewriter.create<LLVM::LoadOp>( |
150 | loc, typeConverter.convertType(argType), arg); |
151 | UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); |
152 | continue; |
153 | } |
154 | |
155 | args.push_back(arg); |
156 | } |
157 | |
158 | auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args); |
159 | |
160 | if (resultStructType) { |
161 | rewriter.create<LLVM::StoreOp>(loc, call.getResult(), |
162 | wrapperFuncOp.getArgument(0)); |
163 | rewriter.create<LLVM::ReturnOp>(loc, ValueRange{}); |
164 | } else { |
165 | rewriter.create<LLVM::ReturnOp>(loc, call.getResults()); |
166 | } |
167 | } |
168 | |
169 | /// Creates an auxiliary function with pointer-to-memref-descriptor-struct |
170 | /// arguments instead of unpacked arguments. Creates a body for the (external) |
171 | /// `newFuncOp` that allocates a memref descriptor on stack, packs the |
172 | /// individual arguments into this descriptor and passes a pointer to it into |
173 | /// the auxiliary function. If the result of the function cannot be directly |
174 | /// returned, we write it to a special first argument that provides a pointer |
175 | /// to a corresponding struct. This auxiliary external function is now |
176 | /// compatible with functions defined in C using pointers to C structs |
177 | /// corresponding to a memref descriptor. |
178 | static void wrapExternalFunction(OpBuilder &builder, Location loc, |
179 | const LLVMTypeConverter &typeConverter, |
180 | FunctionOpInterface funcOp, |
181 | LLVM::LLVMFuncOp newFuncOp) { |
182 | OpBuilder::InsertionGuard guard(builder); |
183 | |
184 | auto [wrapperType, resultStructType] = |
185 | typeConverter.convertFunctionTypeCWrapper( |
186 | type: cast<FunctionType>(funcOp.getFunctionType())); |
187 | // This conversion can only fail if it could not convert one of the argument |
188 | // types. But since it has been applied to a non-wrapper function before, it |
189 | // should have failed earlier and not reach this point at all. |
190 | assert(wrapperType && "unexpected type conversion failure"); |
191 | |
192 | SmallVector<NamedAttribute, 4> attributes; |
193 | filterFuncAttributes(funcOp, attributes); |
194 | |
195 | // Create the auxiliary function. |
196 | auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>( |
197 | loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), |
198 | wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false, |
199 | /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); |
200 | propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc); |
201 | |
202 | // The wrapper that we synthetize here should only be visible in this module. |
203 | newFuncOp.setLinkage(LLVM::Linkage::Private); |
204 | builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder)); |
205 | |
206 | // Get a ValueRange containing arguments. |
207 | FunctionType type = cast<FunctionType>(funcOp.getFunctionType()); |
208 | SmallVector<Value, 8> args; |
209 | args.reserve(N: type.getNumInputs()); |
210 | ValueRange wrapperArgsRange(newFuncOp.getArguments()); |
211 | |
212 | if (resultStructType) { |
213 | // Allocate the struct on the stack and pass the pointer. |
214 | Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0); |
215 | Value one = builder.create<LLVM::ConstantOp>( |
216 | loc, typeConverter.convertType(builder.getIndexType()), |
217 | builder.getIntegerAttr(builder.getIndexType(), 1)); |
218 | Value result = |
219 | builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one); |
220 | args.push_back(Elt: result); |
221 | } |
222 | |
223 | // Iterate over the inputs of the original function and pack values into |
224 | // memref descriptors if the original type is a memref. |
225 | for (Type input : type.getInputs()) { |
226 | Value arg; |
227 | int numToDrop = 1; |
228 | auto memRefType = dyn_cast<MemRefType>(input); |
229 | auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input); |
230 | if (memRefType || unrankedMemRefType) { |
231 | numToDrop = memRefType |
232 | ? MemRefDescriptor::getNumUnpackedValues(memRefType) |
233 | : UnrankedMemRefDescriptor::getNumUnpackedValues(); |
234 | Value packed = |
235 | memRefType |
236 | ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, |
237 | wrapperArgsRange.take_front(numToDrop)) |
238 | : UnrankedMemRefDescriptor::pack( |
239 | builder, loc, typeConverter, unrankedMemRefType, |
240 | wrapperArgsRange.take_front(numToDrop)); |
241 | |
242 | auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); |
243 | Value one = builder.create<LLVM::ConstantOp>( |
244 | loc, typeConverter.convertType(builder.getIndexType()), |
245 | builder.getIntegerAttr(builder.getIndexType(), 1)); |
246 | Value allocated = builder.create<LLVM::AllocaOp>( |
247 | loc, ptrTy, packed.getType(), one, /*alignment=*/0); |
248 | builder.create<LLVM::StoreOp>(loc, packed, allocated); |
249 | arg = allocated; |
250 | } else { |
251 | arg = wrapperArgsRange[0]; |
252 | } |
253 | |
254 | args.push_back(arg); |
255 | wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); |
256 | } |
257 | assert(wrapperArgsRange.empty() && "did not map some of the arguments"); |
258 | |
259 | auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args); |
260 | |
261 | if (resultStructType) { |
262 | Value result = |
263 | builder.create<LLVM::LoadOp>(loc, resultStructType, args.front()); |
264 | builder.create<LLVM::ReturnOp>(loc, result); |
265 | } else { |
266 | builder.create<LLVM::ReturnOp>(loc, call.getResults()); |
267 | } |
268 | } |
269 | |
270 | /// Inserts `llvm.load` ops in the function body to restore the expected pointee |
271 | /// value from `llvm.byval`/`llvm.byref` function arguments that were converted |
272 | /// to LLVM pointer types. |
273 | static void restoreByValRefArgumentType( |
274 | ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, |
275 | ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs, |
276 | ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) { |
277 | // Nothing to do for function declarations. |
278 | if (funcOp.isExternal()) |
279 | return; |
280 | |
281 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
282 | rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); |
283 | |
284 | for (const auto &[arg, oldArg, byValRefAttr] : |
285 | llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) { |
286 | // Skip argument if no `llvm.byval` or `llvm.byref` attribute. |
287 | if (!byValRefAttr) |
288 | continue; |
289 | |
290 | // Insert load to retrieve the actual argument passed by value/reference. |
291 | assert(isa<LLVM::LLVMPointerType>(arg.getType()) && |
292 | "Expected LLVM pointer type for argument with " |
293 | "`llvm.byval`/`llvm.byref` attribute"); |
294 | Type resTy = typeConverter.convertType( |
295 | cast<TypeAttr>(byValRefAttr->getValue()).getValue()); |
296 | |
297 | auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg); |
298 | rewriter.replaceUsesOfBlockArgument(oldArg, valueArg); |
299 | } |
300 | } |
301 | |
302 | FailureOr<LLVM::LLVMFuncOp> |
303 | mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, |
304 | ConversionPatternRewriter &rewriter, |
305 | const LLVMTypeConverter &converter) { |
306 | // Check the funcOp has `FunctionType`. |
307 | auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType()); |
308 | if (!funcTy) |
309 | return rewriter.notifyMatchFailure( |
310 | funcOp, "Only support FunctionOpInterface with FunctionType"); |
311 | |
312 | // Keep track of the entry block arguments. They will be needed later. |
313 | SmallVector<BlockArgument> oldBlockArgs = |
314 | llvm::to_vector(funcOp.getArguments()); |
315 | |
316 | // Convert the original function arguments. They are converted using the |
317 | // LLVMTypeConverter provided to this legalization pattern. |
318 | auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName); |
319 | // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was |
320 | // overriden with an LLVM pointer type for later processing. |
321 | SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs; |
322 | TypeConverter::SignatureConversion result(funcOp.getNumArguments()); |
323 | auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>( |
324 | converter.convertFunctionSignature( |
325 | funcOp, varargsAttr && varargsAttr.getValue(), |
326 | shouldUseBarePtrCallConv(funcOp, &converter), result, |
327 | byValRefNonPtrAttrs)); |
328 | if (!llvmType) |
329 | return rewriter.notifyMatchFailure(funcOp, "signature conversion failed"); |
330 | |
331 | // Check for unsupported variadic functions. |
332 | if (!shouldUseBarePtrCallConv(funcOp, &converter)) |
333 | if (funcOp->getAttrOfType<UnitAttr>( |
334 | LLVM::LLVMDialect::getEmitCWrapperAttrName())) |
335 | if (llvmType.isVarArg()) |
336 | return funcOp.emitError("C interface for variadic functions is not " |
337 | "supported yet."); |
338 | |
339 | // Create an LLVM function, use external linkage by default until MLIR |
340 | // functions have linkage. |
341 | LLVM::Linkage linkage = LLVM::Linkage::External; |
342 | if (funcOp->hasAttr(linkageAttrName)) { |
343 | auto attr = |
344 | dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName)); |
345 | if (!attr) { |
346 | funcOp->emitError() << "Contains "<< linkageAttrName |
347 | << " attribute not of type LLVM::LinkageAttr"; |
348 | return rewriter.notifyMatchFailure( |
349 | funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr"); |
350 | } |
351 | linkage = attr.getLinkage(); |
352 | } |
353 | |
354 | // Check for invalid attributes. |
355 | StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName(); |
356 | if (funcOp->hasAttr(readnoneAttrName)) { |
357 | auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName); |
358 | if (!attr) { |
359 | funcOp->emitError() << "Contains "<< readnoneAttrName |
360 | << " attribute not of type UnitAttr"; |
361 | return rewriter.notifyMatchFailure( |
362 | funcOp, "Contains readnone attribute not of type UnitAttr"); |
363 | } |
364 | } |
365 | |
366 | SmallVector<NamedAttribute, 4> attributes; |
367 | filterFuncAttributes(funcOp, attributes); |
368 | auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( |
369 | funcOp.getLoc(), funcOp.getName(), llvmType, linkage, |
370 | /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, |
371 | attributes); |
372 | cast<FunctionOpInterface>(newFuncOp.getOperation()) |
373 | .setVisibility(funcOp.getVisibility()); |
374 | |
375 | // Create a memory effect attribute corresponding to readnone. |
376 | if (funcOp->hasAttr(readnoneAttrName)) { |
377 | auto memoryAttr = LLVM::MemoryEffectsAttr::get( |
378 | rewriter.getContext(), |
379 | {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef, |
380 | LLVM::ModRefInfo::NoModRef}); |
381 | newFuncOp.setMemoryEffectsAttr(memoryAttr); |
382 | } |
383 | |
384 | // Propagate argument/result attributes to all converted arguments/result |
385 | // obtained after converting a given original argument/result. |
386 | if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { |
387 | assert(!resAttrDicts.empty() && "expected array to be non-empty"); |
388 | if (funcOp.getNumResults() == 1) |
389 | newFuncOp.setAllResultAttrs(resAttrDicts); |
390 | } |
391 | if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { |
392 | SmallVector<Attribute> newArgAttrs( |
393 | cast<LLVM::LLVMFunctionType>(llvmType).getNumParams()); |
394 | for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { |
395 | // Some LLVM IR attribute have a type attached to them. During FuncOp -> |
396 | // LLVMFuncOp conversion these types may have changed. Account for that |
397 | // change by converting attributes' types as well. |
398 | SmallVector<NamedAttribute, 4> convertedAttrs; |
399 | auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]); |
400 | convertedAttrs.reserve(N: attrsDict.size()); |
401 | for (const NamedAttribute &attr : attrsDict) { |
402 | const auto convert = [&](const NamedAttribute &attr) { |
403 | return TypeAttr::get(converter.convertType( |
404 | cast<TypeAttr>(attr.getValue()).getValue())); |
405 | }; |
406 | if (attr.getName().getValue() == |
407 | LLVM::LLVMDialect::getByValAttrName()) { |
408 | convertedAttrs.push_back(rewriter.getNamedAttr( |
409 | LLVM::LLVMDialect::getByValAttrName(), convert(attr))); |
410 | } else if (attr.getName().getValue() == |
411 | LLVM::LLVMDialect::getByRefAttrName()) { |
412 | convertedAttrs.push_back(rewriter.getNamedAttr( |
413 | LLVM::LLVMDialect::getByRefAttrName(), convert(attr))); |
414 | } else if (attr.getName().getValue() == |
415 | LLVM::LLVMDialect::getStructRetAttrName()) { |
416 | convertedAttrs.push_back(rewriter.getNamedAttr( |
417 | LLVM::LLVMDialect::getStructRetAttrName(), convert(attr))); |
418 | } else if (attr.getName().getValue() == |
419 | LLVM::LLVMDialect::getInAllocaAttrName()) { |
420 | convertedAttrs.push_back(rewriter.getNamedAttr( |
421 | LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr))); |
422 | } else { |
423 | convertedAttrs.push_back(attr); |
424 | } |
425 | } |
426 | auto mapping = result.getInputMapping(input: i); |
427 | assert(mapping && "unexpected deletion of function argument"); |
428 | // Only attach the new argument attributes if there is a one-to-one |
429 | // mapping from old to new types. Otherwise, attributes might be |
430 | // attached to types that they do not support. |
431 | if (mapping->size == 1) { |
432 | newArgAttrs[mapping->inputNo] = |
433 | DictionaryAttr::get(rewriter.getContext(), convertedAttrs); |
434 | continue; |
435 | } |
436 | // TODO: Implement custom handling for types that expand to multiple |
437 | // function arguments. |
438 | for (size_t j = 0; j < mapping->size; ++j) |
439 | newArgAttrs[mapping->inputNo + j] = |
440 | DictionaryAttr::get(rewriter.getContext(), {}); |
441 | } |
442 | if (!newArgAttrs.empty()) |
443 | newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs)); |
444 | } |
445 | |
446 | rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(), |
447 | newFuncOp.end()); |
448 | // Convert just the entry block. The remaining unstructured control flow is |
449 | // converted by ControlFlowToLLVM. |
450 | if (!newFuncOp.getBody().empty()) |
451 | rewriter.applySignatureConversion(block: &newFuncOp.getBody().front(), conversion&: result, |
452 | converter: &converter); |
453 | |
454 | // Fix the type mismatch between the materialized `llvm.ptr` and the expected |
455 | // pointee type in the function body when converting `llvm.byval`/`llvm.byref` |
456 | // function arguments. |
457 | restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs, |
458 | oldBlockArgs, newFuncOp); |
459 | |
460 | if (!shouldUseBarePtrCallConv(funcOp, &converter)) { |
461 | if (funcOp->getAttrOfType<UnitAttr>( |
462 | LLVM::LLVMDialect::getEmitCWrapperAttrName())) { |
463 | if (newFuncOp.isExternal()) |
464 | wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp, |
465 | newFuncOp); |
466 | else |
467 | wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp, |
468 | newFuncOp); |
469 | } |
470 | } |
471 | |
472 | return newFuncOp; |
473 | } |
474 | |
475 | namespace { |
476 | |
477 | /// FuncOp legalization pattern that converts MemRef arguments to pointers to |
478 | /// MemRef descriptors (LLVM struct data types) containing all the MemRef type |
479 | /// information. |
480 | struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> { |
481 | FuncOpConversion(const LLVMTypeConverter &converter) |
482 | : ConvertOpToLLVMPattern(converter) {} |
483 | |
484 | LogicalResult |
485 | matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, |
486 | ConversionPatternRewriter &rewriter) const override { |
487 | FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp( |
488 | cast<FunctionOpInterface>(funcOp.getOperation()), rewriter, |
489 | *getTypeConverter()); |
490 | if (failed(Result: newFuncOp)) |
491 | return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop"); |
492 | |
493 | rewriter.eraseOp(op: funcOp); |
494 | return success(); |
495 | } |
496 | }; |
497 | |
498 | struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> { |
499 | using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern; |
500 | |
501 | LogicalResult |
502 | matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor, |
503 | ConversionPatternRewriter &rewriter) const override { |
504 | auto type = typeConverter->convertType(op.getResult().getType()); |
505 | if (!type || !LLVM::isCompatibleType(type: type)) |
506 | return rewriter.notifyMatchFailure(op, "failed to convert result type"); |
507 | |
508 | auto newOp = |
509 | rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue()); |
510 | for (const NamedAttribute &attr : op->getAttrs()) { |
511 | if (attr.getName().strref() == "value") |
512 | continue; |
513 | newOp->setAttr(attr.getName(), attr.getValue()); |
514 | } |
515 | rewriter.replaceOp(op, newOp->getResults()); |
516 | return success(); |
517 | } |
518 | }; |
519 | |
520 | // A CallOp automatically promotes MemRefType to a sequence of alloca/store and |
521 | // passes the pointer to the MemRef across function boundaries. |
522 | template <typename CallOpType> |
523 | struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { |
524 | using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; |
525 | using Super = CallOpInterfaceLowering<CallOpType>; |
526 | using Base = ConvertOpToLLVMPattern<CallOpType>; |
527 | |
528 | LogicalResult matchAndRewriteImpl(CallOpType callOp, |
529 | typename CallOpType::Adaptor adaptor, |
530 | ConversionPatternRewriter &rewriter, |
531 | bool useBarePtrCallConv = false) const { |
532 | // Pack the result types into a struct. |
533 | Type packedResult = nullptr; |
534 | unsigned numResults = callOp.getNumResults(); |
535 | auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); |
536 | |
537 | if (numResults != 0) { |
538 | if (!(packedResult = this->getTypeConverter()->packFunctionResults( |
539 | resultTypes, useBarePtrCallConv))) |
540 | return failure(); |
541 | } |
542 | |
543 | if (useBarePtrCallConv) { |
544 | for (auto it : callOp->getOperands()) { |
545 | Type operandType = it.getType(); |
546 | if (isa<UnrankedMemRefType>(Val: operandType)) { |
547 | // Unranked memref is not supported in the bare pointer calling |
548 | // convention. |
549 | return failure(); |
550 | } |
551 | } |
552 | } |
553 | auto promoted = this->getTypeConverter()->promoteOperands( |
554 | callOp.getLoc(), /*opOperands=*/callOp->getOperands(), |
555 | adaptor.getOperands(), rewriter, useBarePtrCallConv); |
556 | auto newOp = rewriter.create<LLVM::CallOp>( |
557 | callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), |
558 | promoted, callOp->getAttrs()); |
559 | |
560 | newOp.getProperties().operandSegmentSizes = { |
561 | static_cast<int32_t>(promoted.size()), 0}; |
562 | newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); |
563 | |
564 | SmallVector<Value, 4> results; |
565 | if (numResults < 2) { |
566 | // If < 2 results, packing did not do anything and we can just return. |
567 | results.append(newOp.result_begin(), newOp.result_end()); |
568 | } else { |
569 | // Otherwise, it had been converted to an operation producing a structure. |
570 | // Extract individual results from the structure and return them as list. |
571 | results.reserve(N: numResults); |
572 | for (unsigned i = 0; i < numResults; ++i) { |
573 | results.push_back(rewriter.create<LLVM::ExtractValueOp>( |
574 | callOp.getLoc(), newOp->getResult(0), i)); |
575 | } |
576 | } |
577 | |
578 | if (useBarePtrCallConv) { |
579 | // For the bare-ptr calling convention, promote memref results to |
580 | // descriptors. |
581 | assert(results.size() == resultTypes.size() && |
582 | "The number of arguments and types doesn't match"); |
583 | this->getTypeConverter()->promoteBarePtrsToDescriptors( |
584 | rewriter, callOp.getLoc(), resultTypes, results); |
585 | } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), |
586 | resultTypes, results, |
587 | /*toDynamic=*/false))) { |
588 | return failure(); |
589 | } |
590 | |
591 | rewriter.replaceOp(callOp, results); |
592 | return success(); |
593 | } |
594 | }; |
595 | |
596 | class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { |
597 | public: |
598 | CallOpLowering(const LLVMTypeConverter &typeConverter, |
599 | // Can be nullptr. |
600 | const SymbolTable *symbolTable, PatternBenefit benefit = 1) |
601 | : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit), |
602 | symbolTable(symbolTable) {} |
603 | |
604 | LogicalResult |
605 | matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, |
606 | ConversionPatternRewriter &rewriter) const override { |
607 | bool useBarePtrCallConv = false; |
608 | if (getTypeConverter()->getOptions().useBarePtrCallConv) { |
609 | useBarePtrCallConv = true; |
610 | } else if (symbolTable != nullptr) { |
611 | // Fast lookup. |
612 | Operation *callee = |
613 | symbolTable->lookup(callOp.getCalleeAttr().getValue()); |
614 | useBarePtrCallConv = |
615 | callee != nullptr && callee->hasAttr(name: barePtrAttrName); |
616 | } else { |
617 | // Warning: This is a linear lookup. |
618 | Operation *callee = |
619 | SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr()); |
620 | useBarePtrCallConv = |
621 | callee != nullptr && callee->hasAttr(name: barePtrAttrName); |
622 | } |
623 | return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv); |
624 | } |
625 | |
626 | private: |
627 | const SymbolTable *symbolTable = nullptr; |
628 | }; |
629 | |
630 | struct CallIndirectOpLowering |
631 | : public CallOpInterfaceLowering<func::CallIndirectOp> { |
632 | using Super::Super; |
633 | |
634 | LogicalResult |
635 | matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, |
636 | ConversionPatternRewriter &rewriter) const override { |
637 | return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); |
638 | } |
639 | }; |
640 | |
641 | struct UnrealizedConversionCastOpLowering |
642 | : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> { |
643 | using ConvertOpToLLVMPattern< |
644 | UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; |
645 | |
646 | LogicalResult |
647 | matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, |
648 | ConversionPatternRewriter &rewriter) const override { |
649 | SmallVector<Type> convertedTypes; |
650 | if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(), |
651 | convertedTypes)) && |
652 | convertedTypes == adaptor.getInputs().getTypes()) { |
653 | rewriter.replaceOp(op, adaptor.getInputs()); |
654 | return success(); |
655 | } |
656 | |
657 | convertedTypes.clear(); |
658 | if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(), |
659 | convertedTypes)) && |
660 | convertedTypes == op.getOutputs().getType()) { |
661 | rewriter.replaceOp(op, adaptor.getInputs()); |
662 | return success(); |
663 | } |
664 | return failure(); |
665 | } |
666 | }; |
667 | |
668 | // Special lowering pattern for `ReturnOps`. Unlike all other operations, |
669 | // `ReturnOp` interacts with the function signature and must have as many |
670 | // operands as the function has return values. Because in LLVM IR, functions |
671 | // can only return 0 or 1 value, we pack multiple values into a structure type. |
672 | // Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if |
673 | // necessary before returning it |
674 | struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { |
675 | using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; |
676 | |
677 | LogicalResult |
678 | matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, |
679 | ConversionPatternRewriter &rewriter) const override { |
680 | Location loc = op.getLoc(); |
681 | unsigned numArguments = op.getNumOperands(); |
682 | SmallVector<Value, 4> updatedOperands; |
683 | |
684 | auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); |
685 | bool useBarePtrCallConv = |
686 | shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); |
687 | if (useBarePtrCallConv) { |
688 | // For the bare-ptr calling convention, extract the aligned pointer to |
689 | // be returned from the memref descriptor. |
690 | for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { |
691 | Type oldTy = std::get<0>(it).getType(); |
692 | Value newOperand = std::get<1>(it); |
693 | if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( |
694 | cast<BaseMemRefType>(oldTy))) { |
695 | MemRefDescriptor memrefDesc(newOperand); |
696 | newOperand = memrefDesc.allocatedPtr(rewriter, loc); |
697 | } else if (isa<UnrankedMemRefType>(oldTy)) { |
698 | // Unranked memref is not supported in the bare pointer calling |
699 | // convention. |
700 | return failure(); |
701 | } |
702 | updatedOperands.push_back(newOperand); |
703 | } |
704 | } else { |
705 | updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); |
706 | (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), |
707 | updatedOperands, |
708 | /*toDynamic=*/true); |
709 | } |
710 | |
711 | // If ReturnOp has 0 or 1 operand, create it and return immediately. |
712 | if (numArguments <= 1) { |
713 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( |
714 | op, TypeRange(), updatedOperands, op->getAttrs()); |
715 | return success(); |
716 | } |
717 | |
718 | // Otherwise, we need to pack the arguments into an LLVM struct type before |
719 | // returning. |
720 | auto packedType = getTypeConverter()->packFunctionResults( |
721 | op.getOperandTypes(), useBarePtrCallConv); |
722 | if (!packedType) { |
723 | return rewriter.notifyMatchFailure(op, "could not convert result types"); |
724 | } |
725 | |
726 | Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType); |
727 | for (auto [idx, operand] : llvm::enumerate(First&: updatedOperands)) { |
728 | packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); |
729 | } |
730 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, |
731 | op->getAttrs()); |
732 | return success(); |
733 | } |
734 | }; |
735 | } // namespace |
736 | |
737 | void mlir::populateFuncToLLVMFuncOpConversionPattern( |
738 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
739 | patterns.add<FuncOpConversion>(arg: converter); |
740 | } |
741 | |
742 | void mlir::populateFuncToLLVMConversionPatterns( |
743 | const LLVMTypeConverter &converter, RewritePatternSet &patterns, |
744 | const SymbolTable *symbolTable) { |
745 | populateFuncToLLVMFuncOpConversionPattern(converter, patterns); |
746 | patterns.add<CallIndirectOpLowering>(arg: converter); |
747 | patterns.add<CallOpLowering>(arg: converter, args&: symbolTable); |
748 | patterns.add<ConstantOpLowering>(arg: converter); |
749 | patterns.add<ReturnOpLowering>(arg: converter); |
750 | } |
751 | |
752 | namespace { |
753 | /// A pass converting Func operations into the LLVM IR dialect. |
754 | struct ConvertFuncToLLVMPass |
755 | : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> { |
756 | using Base::Base; |
757 | |
758 | /// Run the dialect converter on the module. |
759 | void runOnOperation() override { |
760 | ModuleOp m = getOperation(); |
761 | StringRef dataLayout; |
762 | auto dataLayoutAttr = dyn_cast_or_null<StringAttr>( |
763 | m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())); |
764 | if (dataLayoutAttr) |
765 | dataLayout = dataLayoutAttr.getValue(); |
766 | |
767 | if (failed(LLVM::LLVMDialect::verifyDataLayoutString( |
768 | dataLayout, [this](const Twine &message) { |
769 | getOperation().emitError() << message.str(); |
770 | }))) { |
771 | signalPassFailure(); |
772 | return; |
773 | } |
774 | |
775 | const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); |
776 | |
777 | LowerToLLVMOptions options(&getContext(), |
778 | dataLayoutAnalysis.getAtOrAbove(m)); |
779 | options.useBarePtrCallConv = useBarePtrCallConv; |
780 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
781 | options.overrideIndexBitwidth(indexBitwidth); |
782 | options.dataLayout = llvm::DataLayout(dataLayout); |
783 | |
784 | LLVMTypeConverter typeConverter(&getContext(), options, |
785 | &dataLayoutAnalysis); |
786 | |
787 | std::optional<SymbolTable> optSymbolTable = std::nullopt; |
788 | const SymbolTable *symbolTable = nullptr; |
789 | if (!options.useBarePtrCallConv) { |
790 | optSymbolTable.emplace(m); |
791 | symbolTable = &optSymbolTable.value(); |
792 | } |
793 | |
794 | RewritePatternSet patterns(&getContext()); |
795 | populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns, symbolTable); |
796 | |
797 | LLVMConversionTarget target(getContext()); |
798 | if (failed(applyPartialConversion(m, target, std::move(patterns)))) |
799 | signalPassFailure(); |
800 | } |
801 | }; |
802 | |
803 | struct SetLLVMModuleDataLayoutPass |
804 | : public impl::SetLLVMModuleDataLayoutPassBase< |
805 | SetLLVMModuleDataLayoutPass> { |
806 | using Base::Base; |
807 | |
808 | /// Run the dialect converter on the module. |
809 | void runOnOperation() override { |
810 | if (failed(LLVM::LLVMDialect::verifyDataLayoutString( |
811 | this->dataLayout, [this](const Twine &message) { |
812 | getOperation().emitError() << message.str(); |
813 | }))) { |
814 | signalPassFailure(); |
815 | return; |
816 | } |
817 | ModuleOp m = getOperation(); |
818 | m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), |
819 | StringAttr::get(m.getContext(), this->dataLayout)); |
820 | } |
821 | }; |
822 | } // namespace |
823 | |
824 | //===----------------------------------------------------------------------===// |
825 | // ConvertToLLVMPatternInterface implementation |
826 | //===----------------------------------------------------------------------===// |
827 | |
828 | namespace { |
829 | /// Implement the interface to convert Func to LLVM. |
830 | struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
831 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
832 | /// Hook for derived dialect interface to provide conversion patterns |
833 | /// and mark dialect legal for the conversion target. |
834 | void populateConvertToLLVMConversionPatterns( |
835 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
836 | RewritePatternSet &patterns) const final { |
837 | populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns); |
838 | } |
839 | }; |
840 | } // namespace |
841 | |
842 | void mlir::registerConvertFuncToLLVMInterface(DialectRegistry ®istry) { |
843 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, func::FuncDialect *dialect) { |
844 | dialect->addInterfaces<FuncToLLVMDialectInterface>(); |
845 | }); |
846 | } |
847 |
Definitions
- varargsAttrName
- linkageAttrName
- barePtrAttrName
- shouldUseBarePtrCallConv
- filterFuncAttributes
- propagateArgResAttrs
- wrapForExternalCallers
- wrapExternalFunction
- restoreByValRefArgumentType
- convertFuncOpToLLVMFuncOp
- FuncOpConversion
- FuncOpConversion
- matchAndRewrite
- ConstantOpLowering
- matchAndRewrite
- CallOpInterfaceLowering
- matchAndRewriteImpl
- CallOpLowering
- CallOpLowering
- matchAndRewrite
- CallIndirectOpLowering
- matchAndRewrite
- UnrealizedConversionCastOpLowering
- matchAndRewrite
- ReturnOpLowering
- matchAndRewrite
- populateFuncToLLVMFuncOpConversionPattern
- populateFuncToLLVMConversionPatterns
- ConvertFuncToLLVMPass
- runOnOperation
- SetLLVMModuleDataLayoutPass
- runOnOperation
- FuncToLLVMDialectInterface
- populateConvertToLLVMConversionPatterns
Improve your Profiling and Debugging skills
Find out more