1//===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert MLIR Func and builtin dialects
10// into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
15
16#include "mlir/Analysis/DataLayoutAnalysis.h"
17#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22#include "mlir/Conversion/LLVMCommon/Pattern.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
25#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
27#include "mlir/IR/Attributes.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/BuiltinAttributes.h"
30#include "mlir/IR/BuiltinOps.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/IR/SymbolTable.h"
33#include "mlir/IR/TypeUtilities.h"
34#include "mlir/Transforms/DialectConversion.h"
35#include "mlir/Transforms/Passes.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/IR/Type.h"
38#include "llvm/Support/FormatVariadic.h"
39#include <optional>
40
41namespace mlir {
42#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
43#define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
44#include "mlir/Conversion/Passes.h.inc"
45} // namespace mlir
46
47using namespace mlir;
48
49#define PASS_NAME "convert-func-to-llvm"
50
51static constexpr StringRef varargsAttrName = "func.varargs";
52static constexpr StringRef linkageAttrName = "llvm.linkage";
53static constexpr StringRef barePtrAttrName = "llvm.bareptr";
54
55/// Return `true` if the `op` should use bare pointer calling convention.
56static bool shouldUseBarePtrCallConv(Operation *op,
57 const LLVMTypeConverter *typeConverter) {
58 return (op && op->hasAttr(name: barePtrAttrName)) ||
59 typeConverter->getOptions().useBarePtrCallConv;
60}
61
62/// Only retain those attributes that are not constructed by
63/// `LLVMFuncOp::build`.
64static void filterFuncAttributes(FunctionOpInterface func,
65 SmallVectorImpl<NamedAttribute> &result) {
66 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
67 if (attr.getName() == linkageAttrName ||
68 attr.getName() == varargsAttrName ||
69 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
70 continue;
71 result.push_back(Elt: attr);
72 }
73}
74
75/// Propagate argument/results attributes.
76static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
77 FunctionOpInterface funcOp,
78 LLVM::LLVMFuncOp wrapperFuncOp) {
79 auto argAttrs = funcOp.getAllArgAttrs();
80 if (!resultStructType) {
81 if (auto resAttrs = funcOp.getAllResultAttrs())
82 wrapperFuncOp.setAllResultAttrs(resAttrs);
83 if (argAttrs)
84 wrapperFuncOp.setAllArgAttrs(argAttrs);
85 } else {
86 SmallVector<Attribute> argAttributes;
87 // Only modify the argument and result attributes when the result is now
88 // an argument.
89 if (argAttrs) {
90 argAttributes.push_back(Elt: builder.getDictionaryAttr(value: {}));
91 argAttributes.append(in_start: argAttrs.begin(), in_end: argAttrs.end());
92 wrapperFuncOp.setAllArgAttrs(argAttributes);
93 }
94 }
95 cast<FunctionOpInterface>(Val: wrapperFuncOp.getOperation())
96 .setVisibility(funcOp.getVisibility());
97}
98
99/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
100/// arguments instead of unpacked arguments. This function can be called from C
101/// by passing a pointer to a C struct corresponding to a memref descriptor.
102/// Similarly, returned memrefs are passed via pointers to a C struct that is
103/// passed as additional argument.
104/// Internally, the auxiliary function unpacks the descriptor into individual
105/// components and forwards them to `newFuncOp` and forwards the results to
106/// the extra arguments.
107static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
108 const LLVMTypeConverter &typeConverter,
109 FunctionOpInterface funcOp,
110 LLVM::LLVMFuncOp newFuncOp) {
111 auto type = cast<FunctionType>(Val: funcOp.getFunctionType());
112 auto [wrapperFuncType, resultStructType] =
113 typeConverter.convertFunctionTypeCWrapper(type);
114
115 SmallVector<NamedAttribute> attributes;
116 filterFuncAttributes(func: funcOp, result&: attributes);
117
118 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
119 location: loc, args: llvm::formatv(Fmt: "_mlir_ciface_{0}", Vals: funcOp.getName()).str(),
120 args&: wrapperFuncType, args: LLVM::Linkage::External, /*dsoLocal=*/args: false,
121 /*cconv=*/args: LLVM::CConv::C, /*comdat=*/args: nullptr, args&: attributes);
122 propagateArgResAttrs(builder&: rewriter, resultStructType: !!resultStructType, funcOp, wrapperFuncOp);
123
124 OpBuilder::InsertionGuard guard(rewriter);
125 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(builder&: rewriter));
126
127 SmallVector<Value, 8> args;
128 size_t argOffset = resultStructType ? 1 : 0;
129 for (auto [index, argType] : llvm::enumerate(First: type.getInputs())) {
130 Value arg = wrapperFuncOp.getArgument(idx: index + argOffset);
131 if (auto memrefType = dyn_cast<MemRefType>(Val: argType)) {
132 Value loaded = rewriter.create<LLVM::LoadOp>(
133 location: loc, args: typeConverter.convertType(t: memrefType), args&: arg);
134 MemRefDescriptor::unpack(builder&: rewriter, loc, packed: loaded, type: memrefType, results&: args);
135 continue;
136 }
137 if (isa<UnrankedMemRefType>(Val: argType)) {
138 Value loaded = rewriter.create<LLVM::LoadOp>(
139 location: loc, args: typeConverter.convertType(t: argType), args&: arg);
140 UnrankedMemRefDescriptor::unpack(builder&: rewriter, loc, packed: loaded, results&: args);
141 continue;
142 }
143
144 args.push_back(Elt: arg);
145 }
146
147 auto call = rewriter.create<LLVM::CallOp>(location: loc, args&: newFuncOp, args);
148
149 if (resultStructType) {
150 rewriter.create<LLVM::StoreOp>(location: loc, args: call.getResult(),
151 args: wrapperFuncOp.getArgument(idx: 0));
152 rewriter.create<LLVM::ReturnOp>(location: loc, args: ValueRange{});
153 } else {
154 rewriter.create<LLVM::ReturnOp>(location: loc, args: call.getResults());
155 }
156}
157
158/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
159/// arguments instead of unpacked arguments. Creates a body for the (external)
160/// `newFuncOp` that allocates a memref descriptor on stack, packs the
161/// individual arguments into this descriptor and passes a pointer to it into
162/// the auxiliary function. If the result of the function cannot be directly
163/// returned, we write it to a special first argument that provides a pointer
164/// to a corresponding struct. This auxiliary external function is now
165/// compatible with functions defined in C using pointers to C structs
166/// corresponding to a memref descriptor.
167static void wrapExternalFunction(OpBuilder &builder, Location loc,
168 const LLVMTypeConverter &typeConverter,
169 FunctionOpInterface funcOp,
170 LLVM::LLVMFuncOp newFuncOp) {
171 OpBuilder::InsertionGuard guard(builder);
172
173 auto [wrapperType, resultStructType] =
174 typeConverter.convertFunctionTypeCWrapper(
175 type: cast<FunctionType>(Val: funcOp.getFunctionType()));
176 // This conversion can only fail if it could not convert one of the argument
177 // types. But since it has been applied to a non-wrapper function before, it
178 // should have failed earlier and not reach this point at all.
179 assert(wrapperType && "unexpected type conversion failure");
180
181 SmallVector<NamedAttribute, 4> attributes;
182 filterFuncAttributes(func: funcOp, result&: attributes);
183
184 // Create the auxiliary function.
185 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
186 location: loc, args: llvm::formatv(Fmt: "_mlir_ciface_{0}", Vals: funcOp.getName()).str(),
187 args&: wrapperType, args: LLVM::Linkage::External, /*dsoLocal=*/args: false,
188 /*cconv=*/args: LLVM::CConv::C, /*comdat=*/args: nullptr, args&: attributes);
189 propagateArgResAttrs(builder, resultStructType: !!resultStructType, funcOp, wrapperFuncOp: wrapperFunc);
190
191 // The wrapper that we synthetize here should only be visible in this module.
192 newFuncOp.setLinkage(LLVM::Linkage::Private);
193 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
194
195 // Get a ValueRange containing arguments.
196 FunctionType type = cast<FunctionType>(Val: funcOp.getFunctionType());
197 SmallVector<Value, 8> args;
198 args.reserve(N: type.getNumInputs());
199 ValueRange wrapperArgsRange(newFuncOp.getArguments());
200
201 if (resultStructType) {
202 // Allocate the struct on the stack and pass the pointer.
203 Type resultType = cast<LLVM::LLVMFunctionType>(Val&: wrapperType).getParamType(i: 0);
204 Value one = builder.create<LLVM::ConstantOp>(
205 location: loc, args: typeConverter.convertType(t: builder.getIndexType()),
206 args: builder.getIntegerAttr(type: builder.getIndexType(), value: 1));
207 Value result =
208 builder.create<LLVM::AllocaOp>(location: loc, args&: resultType, args&: resultStructType, args&: one);
209 args.push_back(Elt: result);
210 }
211
212 // Iterate over the inputs of the original function and pack values into
213 // memref descriptors if the original type is a memref.
214 for (Type input : type.getInputs()) {
215 Value arg;
216 int numToDrop = 1;
217 auto memRefType = dyn_cast<MemRefType>(Val&: input);
218 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(Val&: input);
219 if (memRefType || unrankedMemRefType) {
220 numToDrop = memRefType
221 ? MemRefDescriptor::getNumUnpackedValues(type: memRefType)
222 : UnrankedMemRefDescriptor::getNumUnpackedValues();
223 Value packed =
224 memRefType
225 ? MemRefDescriptor::pack(builder, loc, converter: typeConverter, type: memRefType,
226 values: wrapperArgsRange.take_front(n: numToDrop))
227 : UnrankedMemRefDescriptor::pack(
228 builder, loc, converter: typeConverter, type: unrankedMemRefType,
229 values: wrapperArgsRange.take_front(n: numToDrop));
230
231 auto ptrTy = LLVM::LLVMPointerType::get(context: builder.getContext());
232 Value one = builder.create<LLVM::ConstantOp>(
233 location: loc, args: typeConverter.convertType(t: builder.getIndexType()),
234 args: builder.getIntegerAttr(type: builder.getIndexType(), value: 1));
235 Value allocated = builder.create<LLVM::AllocaOp>(
236 location: loc, args&: ptrTy, args: packed.getType(), args&: one, /*alignment=*/args: 0);
237 builder.create<LLVM::StoreOp>(location: loc, args&: packed, args&: allocated);
238 arg = allocated;
239 } else {
240 arg = wrapperArgsRange[0];
241 }
242
243 args.push_back(Elt: arg);
244 wrapperArgsRange = wrapperArgsRange.drop_front(n: numToDrop);
245 }
246 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
247
248 auto call = builder.create<LLVM::CallOp>(location: loc, args&: wrapperFunc, args);
249
250 if (resultStructType) {
251 Value result =
252 builder.create<LLVM::LoadOp>(location: loc, args&: resultStructType, args&: args.front());
253 builder.create<LLVM::ReturnOp>(location: loc, args&: result);
254 } else {
255 builder.create<LLVM::ReturnOp>(location: loc, args: call.getResults());
256 }
257}
258
259/// Inserts `llvm.load` ops in the function body to restore the expected pointee
260/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
261/// to LLVM pointer types.
262static void restoreByValRefArgumentType(
263 ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
264 ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
265 LLVM::LLVMFuncOp funcOp) {
266 // Nothing to do for function declarations.
267 if (funcOp.isExternal())
268 return;
269
270 ConversionPatternRewriter::InsertionGuard guard(rewriter);
271 rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
272
273 for (const auto &[arg, byValRefAttr] :
274 llvm::zip(t: funcOp.getArguments(), u&: byValRefNonPtrAttrs)) {
275 // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
276 if (!byValRefAttr)
277 continue;
278
279 // Insert load to retrieve the actual argument passed by value/reference.
280 assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
281 "Expected LLVM pointer type for argument with "
282 "`llvm.byval`/`llvm.byref` attribute");
283 Type resTy = typeConverter.convertType(
284 t: cast<TypeAttr>(Val: byValRefAttr->getValue()).getValue());
285
286 Value valueArg = rewriter.create<LLVM::LoadOp>(location: arg.getLoc(), args&: resTy, args&: arg);
287 rewriter.replaceUsesOfBlockArgument(from: arg, to: valueArg);
288 }
289}
290
291FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
292 FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
293 const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
294 // Check the funcOp has `FunctionType`.
295 auto funcTy = dyn_cast<FunctionType>(Val: funcOp.getFunctionType());
296 if (!funcTy)
297 return rewriter.notifyMatchFailure(
298 arg&: funcOp, msg: "Only support FunctionOpInterface with FunctionType");
299
300 // Convert the original function arguments. They are converted using the
301 // LLVMTypeConverter provided to this legalization pattern.
302 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(name: varargsAttrName);
303 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
304 // overriden with an LLVM pointer type for later processing.
305 SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
306 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
307 auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>(
308 Val: converter.convertFunctionSignature(
309 funcOp, isVariadic: varargsAttr && varargsAttr.getValue(),
310 useBarePtrCallConv: shouldUseBarePtrCallConv(op: funcOp, typeConverter: &converter), result,
311 byValRefNonPtrAttrs));
312 if (!llvmType)
313 return rewriter.notifyMatchFailure(arg&: funcOp, msg: "signature conversion failed");
314
315 // Check for unsupported variadic functions.
316 if (!shouldUseBarePtrCallConv(op: funcOp, typeConverter: &converter))
317 if (funcOp->getAttrOfType<UnitAttr>(
318 name: LLVM::LLVMDialect::getEmitCWrapperAttrName()))
319 if (llvmType.isVarArg())
320 return funcOp.emitError(message: "C interface for variadic functions is not "
321 "supported yet.");
322
323 // Create an LLVM function, use external linkage by default until MLIR
324 // functions have linkage.
325 LLVM::Linkage linkage = LLVM::Linkage::External;
326 if (funcOp->hasAttr(name: linkageAttrName)) {
327 auto attr =
328 dyn_cast<mlir::LLVM::LinkageAttr>(Val: funcOp->getAttr(name: linkageAttrName));
329 if (!attr) {
330 funcOp->emitError() << "Contains " << linkageAttrName
331 << " attribute not of type LLVM::LinkageAttr";
332 return rewriter.notifyMatchFailure(
333 arg&: funcOp, msg: "Contains linkage attribute not of type LLVM::LinkageAttr");
334 }
335 linkage = attr.getLinkage();
336 }
337
338 // Check for invalid attributes.
339 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
340 if (funcOp->hasAttr(name: readnoneAttrName)) {
341 auto attr = funcOp->getAttrOfType<UnitAttr>(name: readnoneAttrName);
342 if (!attr) {
343 funcOp->emitError() << "Contains " << readnoneAttrName
344 << " attribute not of type UnitAttr";
345 return rewriter.notifyMatchFailure(
346 arg&: funcOp, msg: "Contains readnone attribute not of type UnitAttr");
347 }
348 }
349
350 SmallVector<NamedAttribute, 4> attributes;
351 filterFuncAttributes(func: funcOp, result&: attributes);
352
353 Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
354
355 if (symbolTables && symbolTableOp) {
356 SymbolTable &symbolTable = symbolTables->getSymbolTable(op: symbolTableOp);
357 symbolTable.remove(op: funcOp);
358 }
359
360 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
361 location: funcOp.getLoc(), args: funcOp.getName(), args&: llvmType, args&: linkage,
362 /*dsoLocal=*/args: false, /*cconv=*/args: LLVM::CConv::C, /*comdat=*/args: nullptr,
363 args&: attributes);
364
365 if (symbolTables && symbolTableOp) {
366 auto ip = rewriter.getInsertionPoint();
367 SymbolTable &symbolTable = symbolTables->getSymbolTable(op: symbolTableOp);
368 symbolTable.insert(symbol: newFuncOp, insertPt: ip);
369 }
370
371 cast<FunctionOpInterface>(Val: newFuncOp.getOperation())
372 .setVisibility(funcOp.getVisibility());
373
374 // Create a memory effect attribute corresponding to readnone.
375 if (funcOp->hasAttr(name: readnoneAttrName)) {
376 auto memoryAttr = LLVM::MemoryEffectsAttr::get(
377 context: rewriter.getContext(),
378 memInfoArgs: {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
379 LLVM::ModRefInfo::NoModRef});
380 newFuncOp.setMemoryEffectsAttr(memoryAttr);
381 }
382
383 // Propagate argument/result attributes to all converted arguments/result
384 // obtained after converting a given original argument/result.
385 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
386 assert(!resAttrDicts.empty() && "expected array to be non-empty");
387 if (funcOp.getNumResults() == 1)
388 newFuncOp.setAllResultAttrs(resAttrDicts);
389 }
390 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
391 SmallVector<Attribute> newArgAttrs(
392 cast<LLVM::LLVMFunctionType>(Val&: llvmType).getNumParams());
393 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
394 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
395 // LLVMFuncOp conversion these types may have changed. Account for that
396 // change by converting attributes' types as well.
397 SmallVector<NamedAttribute, 4> convertedAttrs;
398 auto attrsDict = cast<DictionaryAttr>(Val: argAttrDicts[i]);
399 convertedAttrs.reserve(N: attrsDict.size());
400 for (const NamedAttribute &attr : attrsDict) {
401 const auto convert = [&](const NamedAttribute &attr) {
402 return TypeAttr::get(type: converter.convertType(
403 t: cast<TypeAttr>(Val: attr.getValue()).getValue()));
404 };
405 if (attr.getName().getValue() ==
406 LLVM::LLVMDialect::getByValAttrName()) {
407 convertedAttrs.push_back(Elt: rewriter.getNamedAttr(
408 name: LLVM::LLVMDialect::getByValAttrName(), val: convert(attr)));
409 } else if (attr.getName().getValue() ==
410 LLVM::LLVMDialect::getByRefAttrName()) {
411 convertedAttrs.push_back(Elt: rewriter.getNamedAttr(
412 name: LLVM::LLVMDialect::getByRefAttrName(), val: convert(attr)));
413 } else if (attr.getName().getValue() ==
414 LLVM::LLVMDialect::getStructRetAttrName()) {
415 convertedAttrs.push_back(Elt: rewriter.getNamedAttr(
416 name: LLVM::LLVMDialect::getStructRetAttrName(), val: convert(attr)));
417 } else if (attr.getName().getValue() ==
418 LLVM::LLVMDialect::getInAllocaAttrName()) {
419 convertedAttrs.push_back(Elt: rewriter.getNamedAttr(
420 name: LLVM::LLVMDialect::getInAllocaAttrName(), val: convert(attr)));
421 } else {
422 convertedAttrs.push_back(Elt: attr);
423 }
424 }
425 auto mapping = result.getInputMapping(input: i);
426 assert(mapping && "unexpected deletion of function argument");
427 // Only attach the new argument attributes if there is a one-to-one
428 // mapping from old to new types. Otherwise, attributes might be
429 // attached to types that they do not support.
430 if (mapping->size == 1) {
431 newArgAttrs[mapping->inputNo] =
432 DictionaryAttr::get(context: rewriter.getContext(), value: convertedAttrs);
433 continue;
434 }
435 // TODO: Implement custom handling for types that expand to multiple
436 // function arguments.
437 for (size_t j = 0; j < mapping->size; ++j)
438 newArgAttrs[mapping->inputNo + j] =
439 DictionaryAttr::get(context: rewriter.getContext(), value: {});
440 }
441 if (!newArgAttrs.empty())
442 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(value: newArgAttrs));
443 }
444
445 rewriter.inlineRegionBefore(region&: funcOp.getFunctionBody(), parent&: newFuncOp.getBody(),
446 before: newFuncOp.end());
447 // Convert just the entry block. The remaining unstructured control flow is
448 // converted by ControlFlowToLLVM.
449 if (!newFuncOp.getBody().empty())
450 rewriter.applySignatureConversion(block: &newFuncOp.getBody().front(), conversion&: result,
451 converter: &converter);
452
453 // Fix the type mismatch between the materialized `llvm.ptr` and the expected
454 // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
455 // function arguments.
456 restoreByValRefArgumentType(rewriter, typeConverter: converter, byValRefNonPtrAttrs,
457 funcOp: newFuncOp);
458
459 if (!shouldUseBarePtrCallConv(op: funcOp, typeConverter: &converter)) {
460 if (funcOp->getAttrOfType<UnitAttr>(
461 name: LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
462 if (newFuncOp.isExternal())
463 wrapExternalFunction(builder&: rewriter, loc: funcOp->getLoc(), typeConverter: converter, funcOp,
464 newFuncOp);
465 else
466 wrapForExternalCallers(rewriter, loc: funcOp->getLoc(), typeConverter: converter, funcOp,
467 newFuncOp);
468 }
469 }
470
471 return newFuncOp;
472}
473
474namespace {
475
476/// FuncOp legalization pattern that converts MemRef arguments to pointers to
477/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
478/// information.
479class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
480 SymbolTableCollection *symbolTables = nullptr;
481
482public:
483 explicit FuncOpConversion(const LLVMTypeConverter &converter,
484 SymbolTableCollection *symbolTables = nullptr)
485 : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
486
487 LogicalResult
488 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
491 funcOp: cast<FunctionOpInterface>(Val: funcOp.getOperation()), rewriter,
492 converter: *getTypeConverter(), symbolTables);
493 if (failed(Result: newFuncOp))
494 return rewriter.notifyMatchFailure(arg&: funcOp, msg: "Could not convert funcop");
495
496 rewriter.eraseOp(op: funcOp);
497 return success();
498 }
499};
500
501struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
502 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
503
504 LogicalResult
505 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
507 auto type = typeConverter->convertType(t: op.getResult().getType());
508 if (!type || !LLVM::isCompatibleType(type))
509 return rewriter.notifyMatchFailure(arg&: op, msg: "failed to convert result type");
510
511 auto newOp =
512 rewriter.create<LLVM::AddressOfOp>(location: op.getLoc(), args&: type, args: op.getValue());
513 for (const NamedAttribute &attr : op->getAttrs()) {
514 if (attr.getName().strref() == "value")
515 continue;
516 newOp->setAttr(name: attr.getName(), value: attr.getValue());
517 }
518 rewriter.replaceOp(op, newValues: newOp->getResults());
519 return success();
520 }
521};
522
523// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
524// passes the pointer to the MemRef across function boundaries.
525template <typename CallOpType>
526struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
527 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
528 using Super = CallOpInterfaceLowering<CallOpType>;
529 using Base = ConvertOpToLLVMPattern<CallOpType>;
530
531 LogicalResult matchAndRewriteImpl(CallOpType callOp,
532 typename CallOpType::Adaptor adaptor,
533 ConversionPatternRewriter &rewriter,
534 bool useBarePtrCallConv = false) const {
535 // Pack the result types into a struct.
536 Type packedResult = nullptr;
537 unsigned numResults = callOp.getNumResults();
538 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
539
540 if (numResults != 0) {
541 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
542 resultTypes, useBarePtrCallConv)))
543 return failure();
544 }
545
546 if (useBarePtrCallConv) {
547 for (auto it : callOp->getOperands()) {
548 Type operandType = it.getType();
549 if (isa<UnrankedMemRefType>(Val: operandType)) {
550 // Unranked memref is not supported in the bare pointer calling
551 // convention.
552 return failure();
553 }
554 }
555 }
556 auto promoted = this->getTypeConverter()->promoteOperands(
557 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
558 adaptor.getOperands(), rewriter, useBarePtrCallConv);
559 auto newOp = rewriter.create<LLVM::CallOp>(
560 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
561 promoted, callOp->getAttrs());
562
563 newOp.getProperties().operandSegmentSizes = {
564 static_cast<int32_t>(promoted.size()), 0};
565 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr(values: {});
566
567 SmallVector<Value, 4> results;
568 if (numResults < 2) {
569 // If < 2 results, packing did not do anything and we can just return.
570 results.append(newOp.result_begin(), newOp.result_end());
571 } else {
572 // Otherwise, it had been converted to an operation producing a structure.
573 // Extract individual results from the structure and return them as list.
574 results.reserve(N: numResults);
575 for (unsigned i = 0; i < numResults; ++i) {
576 results.push_back(Elt: rewriter.create<LLVM::ExtractValueOp>(
577 callOp.getLoc(), newOp->getResult(0), i));
578 }
579 }
580
581 if (useBarePtrCallConv) {
582 // For the bare-ptr calling convention, promote memref results to
583 // descriptors.
584 assert(results.size() == resultTypes.size() &&
585 "The number of arguments and types doesn't match");
586 this->getTypeConverter()->promoteBarePtrsToDescriptors(
587 rewriter, callOp.getLoc(), resultTypes, results);
588 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
589 resultTypes, results,
590 /*toDynamic=*/false))) {
591 return failure();
592 }
593
594 rewriter.replaceOp(callOp, results);
595 return success();
596 }
597};
598
599class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
600public:
601 explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
602 SymbolTableCollection *symbolTables = nullptr,
603 PatternBenefit benefit = 1)
604 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
605 symbolTables(symbolTables) {}
606
607 LogicalResult
608 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter) const override {
610 bool useBarePtrCallConv = false;
611 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
612 useBarePtrCallConv = true;
613 } else if (symbolTables != nullptr) {
614 // Fast lookup.
615 Operation *callee =
616 symbolTables->lookupNearestSymbolFrom(from: callOp, symbol: callOp.getCalleeAttr());
617 useBarePtrCallConv =
618 callee != nullptr && callee->hasAttr(name: barePtrAttrName);
619 } else {
620 // Warning: This is a linear lookup.
621 Operation *callee =
622 SymbolTable::lookupNearestSymbolFrom(from: callOp, symbol: callOp.getCalleeAttr());
623 useBarePtrCallConv =
624 callee != nullptr && callee->hasAttr(name: barePtrAttrName);
625 }
626 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
627 }
628
629private:
630 SymbolTableCollection *symbolTables = nullptr;
631};
632
633struct CallIndirectOpLowering
634 : public CallOpInterfaceLowering<func::CallIndirectOp> {
635 using Super::Super;
636
637 LogicalResult
638 matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 return matchAndRewriteImpl(callOp: callIndirectOp, adaptor, rewriter);
641 }
642};
643
644struct UnrealizedConversionCastOpLowering
645 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
646 using ConvertOpToLLVMPattern<
647 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
648
649 LogicalResult
650 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
651 ConversionPatternRewriter &rewriter) const override {
652 SmallVector<Type> convertedTypes;
653 if (succeeded(Result: typeConverter->convertTypes(types: op.getOutputs().getTypes(),
654 results&: convertedTypes)) &&
655 convertedTypes == adaptor.getInputs().getTypes()) {
656 rewriter.replaceOp(op, newValues: adaptor.getInputs());
657 return success();
658 }
659
660 convertedTypes.clear();
661 if (succeeded(Result: typeConverter->convertTypes(types: adaptor.getInputs().getTypes(),
662 results&: convertedTypes)) &&
663 convertedTypes == op.getOutputs().getType()) {
664 rewriter.replaceOp(op, newValues: adaptor.getInputs());
665 return success();
666 }
667 return failure();
668 }
669};
670
671// Special lowering pattern for `ReturnOps`. Unlike all other operations,
672// `ReturnOp` interacts with the function signature and must have as many
673// operands as the function has return values. Because in LLVM IR, functions
674// can only return 0 or 1 value, we pack multiple values into a structure type.
675// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
676// necessary before returning it
677struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
678 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
679
680 LogicalResult
681 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter) const override {
683 Location loc = op.getLoc();
684 unsigned numArguments = op.getNumOperands();
685 SmallVector<Value, 4> updatedOperands;
686
687 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
688 bool useBarePtrCallConv =
689 shouldUseBarePtrCallConv(op: funcOp, typeConverter: this->getTypeConverter());
690 if (useBarePtrCallConv) {
691 // For the bare-ptr calling convention, extract the aligned pointer to
692 // be returned from the memref descriptor.
693 for (auto it : llvm::zip(t: op->getOperands(), u: adaptor.getOperands())) {
694 Type oldTy = std::get<0>(t&: it).getType();
695 Value newOperand = std::get<1>(t&: it);
696 if (isa<MemRefType>(Val: oldTy) && getTypeConverter()->canConvertToBarePtr(
697 type: cast<BaseMemRefType>(Val&: oldTy))) {
698 MemRefDescriptor memrefDesc(newOperand);
699 newOperand = memrefDesc.allocatedPtr(builder&: rewriter, loc);
700 } else if (isa<UnrankedMemRefType>(Val: oldTy)) {
701 // Unranked memref is not supported in the bare pointer calling
702 // convention.
703 return failure();
704 }
705 updatedOperands.push_back(Elt: newOperand);
706 }
707 } else {
708 updatedOperands = llvm::to_vector<4>(Range: adaptor.getOperands());
709 (void)copyUnrankedDescriptors(builder&: rewriter, loc, origTypes: op.getOperands().getTypes(),
710 operands&: updatedOperands,
711 /*toDynamic=*/true);
712 }
713
714 // If ReturnOp has 0 or 1 operand, create it and return immediately.
715 if (numArguments <= 1) {
716 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
717 op, args: TypeRange(), args&: updatedOperands, args: op->getAttrs());
718 return success();
719 }
720
721 // Otherwise, we need to pack the arguments into an LLVM struct type before
722 // returning.
723 auto packedType = getTypeConverter()->packFunctionResults(
724 types: op.getOperandTypes(), useBarePointerCallConv: useBarePtrCallConv);
725 if (!packedType) {
726 return rewriter.notifyMatchFailure(arg&: op, msg: "could not convert result types");
727 }
728
729 Value packed = rewriter.create<LLVM::PoisonOp>(location: loc, args&: packedType);
730 for (auto [idx, operand] : llvm::enumerate(First&: updatedOperands)) {
731 packed = rewriter.create<LLVM::InsertValueOp>(location: loc, args&: packed, args&: operand, args&: idx);
732 }
733 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, args: TypeRange(), args&: packed,
734 args: op->getAttrs());
735 return success();
736 }
737};
738} // namespace
739
740void mlir::populateFuncToLLVMFuncOpConversionPattern(
741 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
742 SymbolTableCollection *symbolTables) {
743 patterns.add<FuncOpConversion>(arg: converter, args&: symbolTables);
744}
745
746void mlir::populateFuncToLLVMConversionPatterns(
747 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
748 SymbolTableCollection *symbolTables) {
749 populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
750 patterns.add<CallIndirectOpLowering>(arg: converter);
751 patterns.add<CallOpLowering>(arg: converter, args&: symbolTables);
752 patterns.add<ConstantOpLowering>(arg: converter);
753 patterns.add<ReturnOpLowering>(arg: converter);
754}
755
756namespace {
757/// A pass converting Func operations into the LLVM IR dialect.
758struct ConvertFuncToLLVMPass
759 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
760 using Base::Base;
761
762 /// Run the dialect converter on the module.
763 void runOnOperation() override {
764 ModuleOp m = getOperation();
765 StringRef dataLayout;
766 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
767 Val: m->getAttr(name: LLVM::LLVMDialect::getDataLayoutAttrName()));
768 if (dataLayoutAttr)
769 dataLayout = dataLayoutAttr.getValue();
770
771 if (failed(Result: LLVM::LLVMDialect::verifyDataLayoutString(
772 descr: dataLayout, reportError: [this](const Twine &message) {
773 getOperation().emitError() << message.str();
774 }))) {
775 signalPassFailure();
776 return;
777 }
778
779 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
780
781 LowerToLLVMOptions options(&getContext(),
782 dataLayoutAnalysis.getAtOrAbove(operation: m));
783 options.useBarePtrCallConv = useBarePtrCallConv;
784 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
785 options.overrideIndexBitwidth(bitwidth: indexBitwidth);
786 options.dataLayout = llvm::DataLayout(dataLayout);
787
788 LLVMTypeConverter typeConverter(&getContext(), options,
789 &dataLayoutAnalysis);
790
791 RewritePatternSet patterns(&getContext());
792 SymbolTableCollection symbolTables;
793
794 populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns,
795 symbolTables: &symbolTables);
796
797 LLVMConversionTarget target(getContext());
798 if (failed(Result: applyPartialConversion(op: m, target, patterns: std::move(patterns))))
799 signalPassFailure();
800 }
801};
802
803struct SetLLVMModuleDataLayoutPass
804 : public impl::SetLLVMModuleDataLayoutPassBase<
805 SetLLVMModuleDataLayoutPass> {
806 using Base::Base;
807
808 /// Run the dialect converter on the module.
809 void runOnOperation() override {
810 if (failed(Result: LLVM::LLVMDialect::verifyDataLayoutString(
811 descr: this->dataLayout, reportError: [this](const Twine &message) {
812 getOperation().emitError() << message.str();
813 }))) {
814 signalPassFailure();
815 return;
816 }
817 ModuleOp m = getOperation();
818 m->setAttr(name: LLVM::LLVMDialect::getDataLayoutAttrName(),
819 value: StringAttr::get(context: m.getContext(), bytes: this->dataLayout));
820 }
821};
822} // namespace
823
824//===----------------------------------------------------------------------===//
825// ConvertToLLVMPatternInterface implementation
826//===----------------------------------------------------------------------===//
827
828namespace {
829/// Implement the interface to convert Func to LLVM.
830struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
831 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
832 /// Hook for derived dialect interface to provide conversion patterns
833 /// and mark dialect legal for the conversion target.
834 void populateConvertToLLVMConversionPatterns(
835 ConversionTarget &target, LLVMTypeConverter &typeConverter,
836 RewritePatternSet &patterns) const final {
837 populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns);
838 }
839};
840} // namespace
841
842void mlir::registerConvertFuncToLLVMInterface(DialectRegistry &registry) {
843 registry.addExtension(extensionFn: +[](MLIRContext *ctx, func::FuncDialect *dialect) {
844 dialect->addInterfaces<FuncToLLVMDialectInterface>();
845 });
846}
847

source code of mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp