1//===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert MLIR Func and builtin dialects
10// into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
15
16#include "mlir/Analysis/DataLayoutAnalysis.h"
17#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22#include "mlir/Conversion/LLVMCommon/Pattern.h"
23#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
24#include "mlir/Dialect/Func/IR/FuncOps.h"
25#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
26#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
28#include "mlir/Dialect/Utils/StaticValueUtils.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinAttributeInterfaces.h"
32#include "mlir/IR/BuiltinAttributes.h"
33#include "mlir/IR/BuiltinOps.h"
34#include "mlir/IR/IRMapping.h"
35#include "mlir/IR/PatternMatch.h"
36#include "mlir/IR/SymbolTable.h"
37#include "mlir/IR/TypeUtilities.h"
38#include "mlir/Transforms/DialectConversion.h"
39#include "mlir/Transforms/Passes.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/TypeSwitch.h"
42#include "llvm/IR/DerivedTypes.h"
43#include "llvm/IR/IRBuilder.h"
44#include "llvm/IR/Type.h"
45#include "llvm/Support/Casting.h"
46#include "llvm/Support/CommandLine.h"
47#include "llvm/Support/FormatVariadic.h"
48#include <algorithm>
49#include <functional>
50#include <optional>
51
52namespace mlir {
53#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54#define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55#include "mlir/Conversion/Passes.h.inc"
56} // namespace mlir
57
58using namespace mlir;
59
60#define PASS_NAME "convert-func-to-llvm"
61
62static constexpr StringRef varargsAttrName = "func.varargs";
63static constexpr StringRef linkageAttrName = "llvm.linkage";
64static constexpr StringRef barePtrAttrName = "llvm.bareptr";
65
66/// Return `true` if the `op` should use bare pointer calling convention.
67static bool shouldUseBarePtrCallConv(Operation *op,
68 const LLVMTypeConverter *typeConverter) {
69 return (op && op->hasAttr(name: barePtrAttrName)) ||
70 typeConverter->getOptions().useBarePtrCallConv;
71}
72
73/// Only retain those attributes that are not constructed by
74/// `LLVMFuncOp::build`.
75static void filterFuncAttributes(FunctionOpInterface func,
76 SmallVectorImpl<NamedAttribute> &result) {
77 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
78 if (attr.getName() == linkageAttrName ||
79 attr.getName() == varargsAttrName ||
80 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
81 continue;
82 result.push_back(attr);
83 }
84}
85
86/// Propagate argument/results attributes.
87static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
88 FunctionOpInterface funcOp,
89 LLVM::LLVMFuncOp wrapperFuncOp) {
90 auto argAttrs = funcOp.getAllArgAttrs();
91 if (!resultStructType) {
92 if (auto resAttrs = funcOp.getAllResultAttrs())
93 wrapperFuncOp.setAllResultAttrs(resAttrs);
94 if (argAttrs)
95 wrapperFuncOp.setAllArgAttrs(argAttrs);
96 } else {
97 SmallVector<Attribute> argAttributes;
98 // Only modify the argument and result attributes when the result is now
99 // an argument.
100 if (argAttrs) {
101 argAttributes.push_back(builder.getDictionaryAttr({}));
102 argAttributes.append(argAttrs.begin(), argAttrs.end());
103 wrapperFuncOp.setAllArgAttrs(argAttributes);
104 }
105 }
106 cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107 .setVisibility(funcOp.getVisibility());
108}
109
110/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111/// arguments instead of unpacked arguments. This function can be called from C
112/// by passing a pointer to a C struct corresponding to a memref descriptor.
113/// Similarly, returned memrefs are passed via pointers to a C struct that is
114/// passed as additional argument.
115/// Internally, the auxiliary function unpacks the descriptor into individual
116/// components and forwards them to `newFuncOp` and forwards the results to
117/// the extra arguments.
118static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
119 const LLVMTypeConverter &typeConverter,
120 FunctionOpInterface funcOp,
121 LLVM::LLVMFuncOp newFuncOp) {
122 auto type = cast<FunctionType>(funcOp.getFunctionType());
123 auto [wrapperFuncType, resultStructType] =
124 typeConverter.convertFunctionTypeCWrapper(type: type);
125
126 SmallVector<NamedAttribute> attributes;
127 filterFuncAttributes(funcOp, attributes);
128
129 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
130 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
131 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
132 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
133 propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
134
135 OpBuilder::InsertionGuard guard(rewriter);
136 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
137
138 SmallVector<Value, 8> args;
139 size_t argOffset = resultStructType ? 1 : 0;
140 for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
141 Value arg = wrapperFuncOp.getArgument(index + argOffset);
142 if (auto memrefType = dyn_cast<MemRefType>(argType)) {
143 Value loaded = rewriter.create<LLVM::LoadOp>(
144 loc, typeConverter.convertType(memrefType), arg);
145 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
146 continue;
147 }
148 if (isa<UnrankedMemRefType>(argType)) {
149 Value loaded = rewriter.create<LLVM::LoadOp>(
150 loc, typeConverter.convertType(argType), arg);
151 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
152 continue;
153 }
154
155 args.push_back(arg);
156 }
157
158 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
159
160 if (resultStructType) {
161 rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
162 wrapperFuncOp.getArgument(0));
163 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
164 } else {
165 rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
166 }
167}
168
169/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170/// arguments instead of unpacked arguments. Creates a body for the (external)
171/// `newFuncOp` that allocates a memref descriptor on stack, packs the
172/// individual arguments into this descriptor and passes a pointer to it into
173/// the auxiliary function. If the result of the function cannot be directly
174/// returned, we write it to a special first argument that provides a pointer
175/// to a corresponding struct. This auxiliary external function is now
176/// compatible with functions defined in C using pointers to C structs
177/// corresponding to a memref descriptor.
178static void wrapExternalFunction(OpBuilder &builder, Location loc,
179 const LLVMTypeConverter &typeConverter,
180 FunctionOpInterface funcOp,
181 LLVM::LLVMFuncOp newFuncOp) {
182 OpBuilder::InsertionGuard guard(builder);
183
184 auto [wrapperType, resultStructType] =
185 typeConverter.convertFunctionTypeCWrapper(
186 type: cast<FunctionType>(funcOp.getFunctionType()));
187 // This conversion can only fail if it could not convert one of the argument
188 // types. But since it has been applied to a non-wrapper function before, it
189 // should have failed earlier and not reach this point at all.
190 assert(wrapperType && "unexpected type conversion failure");
191
192 SmallVector<NamedAttribute, 4> attributes;
193 filterFuncAttributes(funcOp, attributes);
194
195 // Create the auxiliary function.
196 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
197 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
198 wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
199 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
200 propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
201
202 // The wrapper that we synthetize here should only be visible in this module.
203 newFuncOp.setLinkage(LLVM::Linkage::Private);
204 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
205
206 // Get a ValueRange containing arguments.
207 FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
208 SmallVector<Value, 8> args;
209 args.reserve(N: type.getNumInputs());
210 ValueRange wrapperArgsRange(newFuncOp.getArguments());
211
212 if (resultStructType) {
213 // Allocate the struct on the stack and pass the pointer.
214 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
215 Value one = builder.create<LLVM::ConstantOp>(
216 loc, typeConverter.convertType(builder.getIndexType()),
217 builder.getIntegerAttr(builder.getIndexType(), 1));
218 Value result =
219 builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220 args.push_back(Elt: result);
221 }
222
223 // Iterate over the inputs of the original function and pack values into
224 // memref descriptors if the original type is a memref.
225 for (Type input : type.getInputs()) {
226 Value arg;
227 int numToDrop = 1;
228 auto memRefType = dyn_cast<MemRefType>(input);
229 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230 if (memRefType || unrankedMemRefType) {
231 numToDrop = memRefType
232 ? MemRefDescriptor::getNumUnpackedValues(memRefType)
233 : UnrankedMemRefDescriptor::getNumUnpackedValues();
234 Value packed =
235 memRefType
236 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
237 wrapperArgsRange.take_front(numToDrop))
238 : UnrankedMemRefDescriptor::pack(
239 builder, loc, typeConverter, unrankedMemRefType,
240 wrapperArgsRange.take_front(numToDrop));
241
242 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
243 Value one = builder.create<LLVM::ConstantOp>(
244 loc, typeConverter.convertType(builder.getIndexType()),
245 builder.getIntegerAttr(builder.getIndexType(), 1));
246 Value allocated = builder.create<LLVM::AllocaOp>(
247 loc, ptrTy, packed.getType(), one, /*alignment=*/0);
248 builder.create<LLVM::StoreOp>(loc, packed, allocated);
249 arg = allocated;
250 } else {
251 arg = wrapperArgsRange[0];
252 }
253
254 args.push_back(arg);
255 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
256 }
257 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
258
259 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
260
261 if (resultStructType) {
262 Value result =
263 builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
264 builder.create<LLVM::ReturnOp>(loc, result);
265 } else {
266 builder.create<LLVM::ReturnOp>(loc, call.getResults());
267 }
268}
269
270/// Inserts `llvm.load` ops in the function body to restore the expected pointee
271/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272/// to LLVM pointer types.
273static void restoreByValRefArgumentType(
274 ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
275 ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
276 ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) {
277 // Nothing to do for function declarations.
278 if (funcOp.isExternal())
279 return;
280
281 ConversionPatternRewriter::InsertionGuard guard(rewriter);
282 rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
283
284 for (const auto &[arg, oldArg, byValRefAttr] :
285 llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) {
286 // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287 if (!byValRefAttr)
288 continue;
289
290 // Insert load to retrieve the actual argument passed by value/reference.
291 assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
292 "Expected LLVM pointer type for argument with "
293 "`llvm.byval`/`llvm.byref` attribute");
294 Type resTy = typeConverter.convertType(
295 cast<TypeAttr>(byValRefAttr->getValue()).getValue());
296
297 auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298 rewriter.replaceUsesOfBlockArgument(oldArg, valueArg);
299 }
300}
301
302FailureOr<LLVM::LLVMFuncOp>
303mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
304 ConversionPatternRewriter &rewriter,
305 const LLVMTypeConverter &converter) {
306 // Check the funcOp has `FunctionType`.
307 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
308 if (!funcTy)
309 return rewriter.notifyMatchFailure(
310 funcOp, "Only support FunctionOpInterface with FunctionType");
311
312 // Keep track of the entry block arguments. They will be needed later.
313 SmallVector<BlockArgument> oldBlockArgs =
314 llvm::to_vector(funcOp.getArguments());
315
316 // Convert the original function arguments. They are converted using the
317 // LLVMTypeConverter provided to this legalization pattern.
318 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
319 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
320 // overriden with an LLVM pointer type for later processing.
321 SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
322 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
323 auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>(
324 converter.convertFunctionSignature(
325 funcOp, varargsAttr && varargsAttr.getValue(),
326 shouldUseBarePtrCallConv(funcOp, &converter), result,
327 byValRefNonPtrAttrs));
328 if (!llvmType)
329 return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
330
331 // Check for unsupported variadic functions.
332 if (!shouldUseBarePtrCallConv(funcOp, &converter))
333 if (funcOp->getAttrOfType<UnitAttr>(
334 LLVM::LLVMDialect::getEmitCWrapperAttrName()))
335 if (llvmType.isVarArg())
336 return funcOp.emitError("C interface for variadic functions is not "
337 "supported yet.");
338
339 // Create an LLVM function, use external linkage by default until MLIR
340 // functions have linkage.
341 LLVM::Linkage linkage = LLVM::Linkage::External;
342 if (funcOp->hasAttr(linkageAttrName)) {
343 auto attr =
344 dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
345 if (!attr) {
346 funcOp->emitError() << "Contains " << linkageAttrName
347 << " attribute not of type LLVM::LinkageAttr";
348 return rewriter.notifyMatchFailure(
349 funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
350 }
351 linkage = attr.getLinkage();
352 }
353
354 // Check for invalid attributes.
355 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
356 if (funcOp->hasAttr(readnoneAttrName)) {
357 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
358 if (!attr) {
359 funcOp->emitError() << "Contains " << readnoneAttrName
360 << " attribute not of type UnitAttr";
361 return rewriter.notifyMatchFailure(
362 funcOp, "Contains readnone attribute not of type UnitAttr");
363 }
364 }
365
366 SmallVector<NamedAttribute, 4> attributes;
367 filterFuncAttributes(funcOp, attributes);
368 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
369 funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
370 /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
371 attributes);
372 cast<FunctionOpInterface>(newFuncOp.getOperation())
373 .setVisibility(funcOp.getVisibility());
374
375 // Create a memory effect attribute corresponding to readnone.
376 if (funcOp->hasAttr(readnoneAttrName)) {
377 auto memoryAttr = LLVM::MemoryEffectsAttr::get(
378 rewriter.getContext(),
379 {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
380 LLVM::ModRefInfo::NoModRef});
381 newFuncOp.setMemoryEffectsAttr(memoryAttr);
382 }
383
384 // Propagate argument/result attributes to all converted arguments/result
385 // obtained after converting a given original argument/result.
386 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
387 assert(!resAttrDicts.empty() && "expected array to be non-empty");
388 if (funcOp.getNumResults() == 1)
389 newFuncOp.setAllResultAttrs(resAttrDicts);
390 }
391 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
392 SmallVector<Attribute> newArgAttrs(
393 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
394 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
395 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
396 // LLVMFuncOp conversion these types may have changed. Account for that
397 // change by converting attributes' types as well.
398 SmallVector<NamedAttribute, 4> convertedAttrs;
399 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
400 convertedAttrs.reserve(N: attrsDict.size());
401 for (const NamedAttribute &attr : attrsDict) {
402 const auto convert = [&](const NamedAttribute &attr) {
403 return TypeAttr::get(converter.convertType(
404 cast<TypeAttr>(attr.getValue()).getValue()));
405 };
406 if (attr.getName().getValue() ==
407 LLVM::LLVMDialect::getByValAttrName()) {
408 convertedAttrs.push_back(rewriter.getNamedAttr(
409 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
410 } else if (attr.getName().getValue() ==
411 LLVM::LLVMDialect::getByRefAttrName()) {
412 convertedAttrs.push_back(rewriter.getNamedAttr(
413 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
414 } else if (attr.getName().getValue() ==
415 LLVM::LLVMDialect::getStructRetAttrName()) {
416 convertedAttrs.push_back(rewriter.getNamedAttr(
417 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
418 } else if (attr.getName().getValue() ==
419 LLVM::LLVMDialect::getInAllocaAttrName()) {
420 convertedAttrs.push_back(rewriter.getNamedAttr(
421 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
422 } else {
423 convertedAttrs.push_back(attr);
424 }
425 }
426 auto mapping = result.getInputMapping(input: i);
427 assert(mapping && "unexpected deletion of function argument");
428 // Only attach the new argument attributes if there is a one-to-one
429 // mapping from old to new types. Otherwise, attributes might be
430 // attached to types that they do not support.
431 if (mapping->size == 1) {
432 newArgAttrs[mapping->inputNo] =
433 DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
434 continue;
435 }
436 // TODO: Implement custom handling for types that expand to multiple
437 // function arguments.
438 for (size_t j = 0; j < mapping->size; ++j)
439 newArgAttrs[mapping->inputNo + j] =
440 DictionaryAttr::get(rewriter.getContext(), {});
441 }
442 if (!newArgAttrs.empty())
443 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
444 }
445
446 rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
447 newFuncOp.end());
448 // Convert just the entry block. The remaining unstructured control flow is
449 // converted by ControlFlowToLLVM.
450 if (!newFuncOp.getBody().empty())
451 rewriter.applySignatureConversion(block: &newFuncOp.getBody().front(), conversion&: result,
452 converter: &converter);
453
454 // Fix the type mismatch between the materialized `llvm.ptr` and the expected
455 // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
456 // function arguments.
457 restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
458 oldBlockArgs, newFuncOp);
459
460 if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
461 if (funcOp->getAttrOfType<UnitAttr>(
462 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
463 if (newFuncOp.isExternal())
464 wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
465 newFuncOp);
466 else
467 wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
468 newFuncOp);
469 }
470 }
471
472 return newFuncOp;
473}
474
475namespace {
476
477/// FuncOp legalization pattern that converts MemRef arguments to pointers to
478/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
479/// information.
480struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
481 FuncOpConversion(const LLVMTypeConverter &converter)
482 : ConvertOpToLLVMPattern(converter) {}
483
484 LogicalResult
485 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter) const override {
487 FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
488 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
489 *getTypeConverter());
490 if (failed(Result: newFuncOp))
491 return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
492
493 rewriter.eraseOp(op: funcOp);
494 return success();
495 }
496};
497
498struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
499 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
500
501 LogicalResult
502 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
503 ConversionPatternRewriter &rewriter) const override {
504 auto type = typeConverter->convertType(op.getResult().getType());
505 if (!type || !LLVM::isCompatibleType(type: type))
506 return rewriter.notifyMatchFailure(op, "failed to convert result type");
507
508 auto newOp =
509 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
510 for (const NamedAttribute &attr : op->getAttrs()) {
511 if (attr.getName().strref() == "value")
512 continue;
513 newOp->setAttr(attr.getName(), attr.getValue());
514 }
515 rewriter.replaceOp(op, newOp->getResults());
516 return success();
517 }
518};
519
520// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
521// passes the pointer to the MemRef across function boundaries.
522template <typename CallOpType>
523struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
524 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
525 using Super = CallOpInterfaceLowering<CallOpType>;
526 using Base = ConvertOpToLLVMPattern<CallOpType>;
527
528 LogicalResult matchAndRewriteImpl(CallOpType callOp,
529 typename CallOpType::Adaptor adaptor,
530 ConversionPatternRewriter &rewriter,
531 bool useBarePtrCallConv = false) const {
532 // Pack the result types into a struct.
533 Type packedResult = nullptr;
534 unsigned numResults = callOp.getNumResults();
535 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
536
537 if (numResults != 0) {
538 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
539 resultTypes, useBarePtrCallConv)))
540 return failure();
541 }
542
543 if (useBarePtrCallConv) {
544 for (auto it : callOp->getOperands()) {
545 Type operandType = it.getType();
546 if (isa<UnrankedMemRefType>(Val: operandType)) {
547 // Unranked memref is not supported in the bare pointer calling
548 // convention.
549 return failure();
550 }
551 }
552 }
553 auto promoted = this->getTypeConverter()->promoteOperands(
554 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
555 adaptor.getOperands(), rewriter, useBarePtrCallConv);
556 auto newOp = rewriter.create<LLVM::CallOp>(
557 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
558 promoted, callOp->getAttrs());
559
560 newOp.getProperties().operandSegmentSizes = {
561 static_cast<int32_t>(promoted.size()), 0};
562 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
563
564 SmallVector<Value, 4> results;
565 if (numResults < 2) {
566 // If < 2 results, packing did not do anything and we can just return.
567 results.append(newOp.result_begin(), newOp.result_end());
568 } else {
569 // Otherwise, it had been converted to an operation producing a structure.
570 // Extract individual results from the structure and return them as list.
571 results.reserve(N: numResults);
572 for (unsigned i = 0; i < numResults; ++i) {
573 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
574 callOp.getLoc(), newOp->getResult(0), i));
575 }
576 }
577
578 if (useBarePtrCallConv) {
579 // For the bare-ptr calling convention, promote memref results to
580 // descriptors.
581 assert(results.size() == resultTypes.size() &&
582 "The number of arguments and types doesn't match");
583 this->getTypeConverter()->promoteBarePtrsToDescriptors(
584 rewriter, callOp.getLoc(), resultTypes, results);
585 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
586 resultTypes, results,
587 /*toDynamic=*/false))) {
588 return failure();
589 }
590
591 rewriter.replaceOp(callOp, results);
592 return success();
593 }
594};
595
596class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
597public:
598 CallOpLowering(const LLVMTypeConverter &typeConverter,
599 // Can be nullptr.
600 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
601 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
602 symbolTable(symbolTable) {}
603
604 LogicalResult
605 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
606 ConversionPatternRewriter &rewriter) const override {
607 bool useBarePtrCallConv = false;
608 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
609 useBarePtrCallConv = true;
610 } else if (symbolTable != nullptr) {
611 // Fast lookup.
612 Operation *callee =
613 symbolTable->lookup(callOp.getCalleeAttr().getValue());
614 useBarePtrCallConv =
615 callee != nullptr && callee->hasAttr(name: barePtrAttrName);
616 } else {
617 // Warning: This is a linear lookup.
618 Operation *callee =
619 SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
620 useBarePtrCallConv =
621 callee != nullptr && callee->hasAttr(name: barePtrAttrName);
622 }
623 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
624 }
625
626private:
627 const SymbolTable *symbolTable = nullptr;
628};
629
630struct CallIndirectOpLowering
631 : public CallOpInterfaceLowering<func::CallIndirectOp> {
632 using Super::Super;
633
634 LogicalResult
635 matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
636 ConversionPatternRewriter &rewriter) const override {
637 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
638 }
639};
640
641struct UnrealizedConversionCastOpLowering
642 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
643 using ConvertOpToLLVMPattern<
644 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
645
646 LogicalResult
647 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
648 ConversionPatternRewriter &rewriter) const override {
649 SmallVector<Type> convertedTypes;
650 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
651 convertedTypes)) &&
652 convertedTypes == adaptor.getInputs().getTypes()) {
653 rewriter.replaceOp(op, adaptor.getInputs());
654 return success();
655 }
656
657 convertedTypes.clear();
658 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
659 convertedTypes)) &&
660 convertedTypes == op.getOutputs().getType()) {
661 rewriter.replaceOp(op, adaptor.getInputs());
662 return success();
663 }
664 return failure();
665 }
666};
667
668// Special lowering pattern for `ReturnOps`. Unlike all other operations,
669// `ReturnOp` interacts with the function signature and must have as many
670// operands as the function has return values. Because in LLVM IR, functions
671// can only return 0 or 1 value, we pack multiple values into a structure type.
672// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
673// necessary before returning it
674struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
675 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
676
677 LogicalResult
678 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
679 ConversionPatternRewriter &rewriter) const override {
680 Location loc = op.getLoc();
681 unsigned numArguments = op.getNumOperands();
682 SmallVector<Value, 4> updatedOperands;
683
684 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
685 bool useBarePtrCallConv =
686 shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
687 if (useBarePtrCallConv) {
688 // For the bare-ptr calling convention, extract the aligned pointer to
689 // be returned from the memref descriptor.
690 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
691 Type oldTy = std::get<0>(it).getType();
692 Value newOperand = std::get<1>(it);
693 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
694 cast<BaseMemRefType>(oldTy))) {
695 MemRefDescriptor memrefDesc(newOperand);
696 newOperand = memrefDesc.allocatedPtr(rewriter, loc);
697 } else if (isa<UnrankedMemRefType>(oldTy)) {
698 // Unranked memref is not supported in the bare pointer calling
699 // convention.
700 return failure();
701 }
702 updatedOperands.push_back(newOperand);
703 }
704 } else {
705 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
706 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
707 updatedOperands,
708 /*toDynamic=*/true);
709 }
710
711 // If ReturnOp has 0 or 1 operand, create it and return immediately.
712 if (numArguments <= 1) {
713 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
714 op, TypeRange(), updatedOperands, op->getAttrs());
715 return success();
716 }
717
718 // Otherwise, we need to pack the arguments into an LLVM struct type before
719 // returning.
720 auto packedType = getTypeConverter()->packFunctionResults(
721 op.getOperandTypes(), useBarePtrCallConv);
722 if (!packedType) {
723 return rewriter.notifyMatchFailure(op, "could not convert result types");
724 }
725
726 Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
727 for (auto [idx, operand] : llvm::enumerate(First&: updatedOperands)) {
728 packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
729 }
730 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
731 op->getAttrs());
732 return success();
733 }
734};
735} // namespace
736
737void mlir::populateFuncToLLVMFuncOpConversionPattern(
738 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
739 patterns.add<FuncOpConversion>(arg: converter);
740}
741
742void mlir::populateFuncToLLVMConversionPatterns(
743 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
744 const SymbolTable *symbolTable) {
745 populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
746 patterns.add<CallIndirectOpLowering>(arg: converter);
747 patterns.add<CallOpLowering>(arg: converter, args&: symbolTable);
748 patterns.add<ConstantOpLowering>(arg: converter);
749 patterns.add<ReturnOpLowering>(arg: converter);
750}
751
752namespace {
753/// A pass converting Func operations into the LLVM IR dialect.
754struct ConvertFuncToLLVMPass
755 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
756 using Base::Base;
757
758 /// Run the dialect converter on the module.
759 void runOnOperation() override {
760 ModuleOp m = getOperation();
761 StringRef dataLayout;
762 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
763 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
764 if (dataLayoutAttr)
765 dataLayout = dataLayoutAttr.getValue();
766
767 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
768 dataLayout, [this](const Twine &message) {
769 getOperation().emitError() << message.str();
770 }))) {
771 signalPassFailure();
772 return;
773 }
774
775 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
776
777 LowerToLLVMOptions options(&getContext(),
778 dataLayoutAnalysis.getAtOrAbove(m));
779 options.useBarePtrCallConv = useBarePtrCallConv;
780 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
781 options.overrideIndexBitwidth(indexBitwidth);
782 options.dataLayout = llvm::DataLayout(dataLayout);
783
784 LLVMTypeConverter typeConverter(&getContext(), options,
785 &dataLayoutAnalysis);
786
787 std::optional<SymbolTable> optSymbolTable = std::nullopt;
788 const SymbolTable *symbolTable = nullptr;
789 if (!options.useBarePtrCallConv) {
790 optSymbolTable.emplace(m);
791 symbolTable = &optSymbolTable.value();
792 }
793
794 RewritePatternSet patterns(&getContext());
795 populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns, symbolTable);
796
797 LLVMConversionTarget target(getContext());
798 if (failed(applyPartialConversion(m, target, std::move(patterns))))
799 signalPassFailure();
800 }
801};
802
803struct SetLLVMModuleDataLayoutPass
804 : public impl::SetLLVMModuleDataLayoutPassBase<
805 SetLLVMModuleDataLayoutPass> {
806 using Base::Base;
807
808 /// Run the dialect converter on the module.
809 void runOnOperation() override {
810 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
811 this->dataLayout, [this](const Twine &message) {
812 getOperation().emitError() << message.str();
813 }))) {
814 signalPassFailure();
815 return;
816 }
817 ModuleOp m = getOperation();
818 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
819 StringAttr::get(m.getContext(), this->dataLayout));
820 }
821};
822} // namespace
823
824//===----------------------------------------------------------------------===//
825// ConvertToLLVMPatternInterface implementation
826//===----------------------------------------------------------------------===//
827
828namespace {
829/// Implement the interface to convert Func to LLVM.
830struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
831 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
832 /// Hook for derived dialect interface to provide conversion patterns
833 /// and mark dialect legal for the conversion target.
834 void populateConvertToLLVMConversionPatterns(
835 ConversionTarget &target, LLVMTypeConverter &typeConverter,
836 RewritePatternSet &patterns) const final {
837 populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns);
838 }
839};
840} // namespace
841
842void mlir::registerConvertFuncToLLVMInterface(DialectRegistry &registry) {
843 registry.addExtension(extensionFn: +[](MLIRContext *ctx, func::FuncDialect *dialect) {
844 dialect->addInterfaces<FuncToLLVMDialectInterface>();
845 });
846}
847

// Source listing of mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp, provided
// by KDAB.