1//===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert MLIR Func and builtin dialects
10// into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
15
16#include "mlir/Analysis/DataLayoutAnalysis.h"
17#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22#include "mlir/Conversion/LLVMCommon/Pattern.h"
23#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
24#include "mlir/Dialect/Func/IR/FuncOps.h"
25#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
26#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
28#include "mlir/Dialect/Utils/StaticValueUtils.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinAttributeInterfaces.h"
32#include "mlir/IR/BuiltinAttributes.h"
33#include "mlir/IR/BuiltinOps.h"
34#include "mlir/IR/IRMapping.h"
35#include "mlir/IR/PatternMatch.h"
36#include "mlir/IR/SymbolTable.h"
37#include "mlir/IR/TypeUtilities.h"
38#include "mlir/Support/LogicalResult.h"
39#include "mlir/Support/MathExtras.h"
40#include "mlir/Transforms/DialectConversion.h"
41#include "mlir/Transforms/Passes.h"
42#include "llvm/ADT/SmallVector.h"
43#include "llvm/ADT/TypeSwitch.h"
44#include "llvm/IR/DerivedTypes.h"
45#include "llvm/IR/IRBuilder.h"
46#include "llvm/IR/Type.h"
47#include "llvm/Support/Casting.h"
48#include "llvm/Support/CommandLine.h"
49#include "llvm/Support/FormatVariadic.h"
50#include <algorithm>
51#include <functional>
52#include <optional>
53
54namespace mlir {
55#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
56#define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
57#include "mlir/Conversion/Passes.h.inc"
58} // namespace mlir
59
60using namespace mlir;
61
62#define PASS_NAME "convert-func-to-llvm"
63
64static constexpr StringRef varargsAttrName = "func.varargs";
65static constexpr StringRef linkageAttrName = "llvm.linkage";
66static constexpr StringRef barePtrAttrName = "llvm.bareptr";
67
68/// Return `true` if the `op` should use bare pointer calling convention.
69static bool shouldUseBarePtrCallConv(Operation *op,
70 const LLVMTypeConverter *typeConverter) {
71 return (op && op->hasAttr(name: barePtrAttrName)) ||
72 typeConverter->getOptions().useBarePtrCallConv;
73}
74
75/// Only retain those attributes that are not constructed by
76/// `LLVMFuncOp::build`.
77static void filterFuncAttributes(FunctionOpInterface func,
78 SmallVectorImpl<NamedAttribute> &result) {
79 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
80 if (attr.getName() == linkageAttrName ||
81 attr.getName() == varargsAttrName ||
82 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
83 continue;
84 result.push_back(attr);
85 }
86}
87
88/// Propagate argument/results attributes.
89static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
90 FunctionOpInterface funcOp,
91 LLVM::LLVMFuncOp wrapperFuncOp) {
92 auto argAttrs = funcOp.getAllArgAttrs();
93 if (!resultStructType) {
94 if (auto resAttrs = funcOp.getAllResultAttrs())
95 wrapperFuncOp.setAllResultAttrs(resAttrs);
96 if (argAttrs)
97 wrapperFuncOp.setAllArgAttrs(argAttrs);
98 } else {
99 SmallVector<Attribute> argAttributes;
100 // Only modify the argument and result attributes when the result is now
101 // an argument.
102 if (argAttrs) {
103 argAttributes.push_back(builder.getDictionaryAttr({}));
104 argAttributes.append(argAttrs.begin(), argAttrs.end());
105 wrapperFuncOp.setAllArgAttrs(argAttributes);
106 }
107 }
108 cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
109 .setVisibility(funcOp.getVisibility());
110}
111
112/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
113/// arguments instead of unpacked arguments. This function can be called from C
114/// by passing a pointer to a C struct corresponding to a memref descriptor.
115/// Similarly, returned memrefs are passed via pointers to a C struct that is
116/// passed as additional argument.
117/// Internally, the auxiliary function unpacks the descriptor into individual
118/// components and forwards them to `newFuncOp` and forwards the results to
119/// the extra arguments.
120static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
121 const LLVMTypeConverter &typeConverter,
122 FunctionOpInterface funcOp,
123 LLVM::LLVMFuncOp newFuncOp) {
124 auto type = cast<FunctionType>(funcOp.getFunctionType());
125 auto [wrapperFuncType, resultStructType] =
126 typeConverter.convertFunctionTypeCWrapper(type: type);
127
128 SmallVector<NamedAttribute> attributes;
129 filterFuncAttributes(funcOp, attributes);
130
131 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
132 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
133 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
134 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
135 propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
136
137 OpBuilder::InsertionGuard guard(rewriter);
138 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
139
140 SmallVector<Value, 8> args;
141 size_t argOffset = resultStructType ? 1 : 0;
142 for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
143 Value arg = wrapperFuncOp.getArgument(index + argOffset);
144 if (auto memrefType = dyn_cast<MemRefType>(argType)) {
145 Value loaded = rewriter.create<LLVM::LoadOp>(
146 loc, typeConverter.convertType(memrefType), arg);
147 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
148 continue;
149 }
150 if (isa<UnrankedMemRefType>(argType)) {
151 Value loaded = rewriter.create<LLVM::LoadOp>(
152 loc, typeConverter.convertType(argType), arg);
153 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
154 continue;
155 }
156
157 args.push_back(arg);
158 }
159
160 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
161
162 if (resultStructType) {
163 rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
164 wrapperFuncOp.getArgument(0));
165 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
166 } else {
167 rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
168 }
169}
170
171/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
172/// arguments instead of unpacked arguments. Creates a body for the (external)
173/// `newFuncOp` that allocates a memref descriptor on stack, packs the
174/// individual arguments into this descriptor and passes a pointer to it into
175/// the auxiliary function. If the result of the function cannot be directly
176/// returned, we write it to a special first argument that provides a pointer
177/// to a corresponding struct. This auxiliary external function is now
178/// compatible with functions defined in C using pointers to C structs
179/// corresponding to a memref descriptor.
180static void wrapExternalFunction(OpBuilder &builder, Location loc,
181 const LLVMTypeConverter &typeConverter,
182 FunctionOpInterface funcOp,
183 LLVM::LLVMFuncOp newFuncOp) {
184 OpBuilder::InsertionGuard guard(builder);
185
186 auto [wrapperType, resultStructType] =
187 typeConverter.convertFunctionTypeCWrapper(
188 type: cast<FunctionType>(funcOp.getFunctionType()));
189 // This conversion can only fail if it could not convert one of the argument
190 // types. But since it has been applied to a non-wrapper function before, it
191 // should have failed earlier and not reach this point at all.
192 assert(wrapperType && "unexpected type conversion failure");
193
194 SmallVector<NamedAttribute, 4> attributes;
195 filterFuncAttributes(funcOp, attributes);
196
197 // Create the auxiliary function.
198 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
199 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
200 wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
201 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
202 propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
203
204 // The wrapper that we synthetize here should only be visible in this module.
205 newFuncOp.setLinkage(LLVM::Linkage::Private);
206 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
207
208 // Get a ValueRange containing arguments.
209 FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
210 SmallVector<Value, 8> args;
211 args.reserve(N: type.getNumInputs());
212 ValueRange wrapperArgsRange(newFuncOp.getArguments());
213
214 if (resultStructType) {
215 // Allocate the struct on the stack and pass the pointer.
216 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
217 Value one = builder.create<LLVM::ConstantOp>(
218 loc, typeConverter.convertType(builder.getIndexType()),
219 builder.getIntegerAttr(builder.getIndexType(), 1));
220 Value result =
221 builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
222 args.push_back(Elt: result);
223 }
224
225 // Iterate over the inputs of the original function and pack values into
226 // memref descriptors if the original type is a memref.
227 for (Type input : type.getInputs()) {
228 Value arg;
229 int numToDrop = 1;
230 auto memRefType = dyn_cast<MemRefType>(input);
231 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
232 if (memRefType || unrankedMemRefType) {
233 numToDrop = memRefType
234 ? MemRefDescriptor::getNumUnpackedValues(memRefType)
235 : UnrankedMemRefDescriptor::getNumUnpackedValues();
236 Value packed =
237 memRefType
238 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
239 wrapperArgsRange.take_front(numToDrop))
240 : UnrankedMemRefDescriptor::pack(
241 builder, loc, typeConverter, unrankedMemRefType,
242 wrapperArgsRange.take_front(numToDrop));
243
244 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
245 Value one = builder.create<LLVM::ConstantOp>(
246 loc, typeConverter.convertType(builder.getIndexType()),
247 builder.getIntegerAttr(builder.getIndexType(), 1));
248 Value allocated = builder.create<LLVM::AllocaOp>(
249 loc, ptrTy, packed.getType(), one, /*alignment=*/0);
250 builder.create<LLVM::StoreOp>(loc, packed, allocated);
251 arg = allocated;
252 } else {
253 arg = wrapperArgsRange[0];
254 }
255
256 args.push_back(arg);
257 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
258 }
259 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
260
261 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
262
263 if (resultStructType) {
264 Value result =
265 builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
266 builder.create<LLVM::ReturnOp>(loc, result);
267 } else {
268 builder.create<LLVM::ReturnOp>(loc, call.getResults());
269 }
270}
271
272/// Modifies the body of the function to construct the `MemRefDescriptor` from
273/// the bare pointer calling convention lowering of `memref` types.
274static void modifyFuncOpToUseBarePtrCallingConv(
275 ConversionPatternRewriter &rewriter, Location loc,
276 const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
277 TypeRange oldArgTypes) {
278 if (funcOp.getBody().empty())
279 return;
280
281 // Promote bare pointers from memref arguments to memref descriptors at the
282 // beginning of the function so that all the memrefs in the function have a
283 // uniform representation.
284 Block *entryBlock = &funcOp.getBody().front();
285 auto blockArgs = entryBlock->getArguments();
286 assert(blockArgs.size() == oldArgTypes.size() &&
287 "The number of arguments and types doesn't match");
288
289 OpBuilder::InsertionGuard guard(rewriter);
290 rewriter.setInsertionPointToStart(entryBlock);
291 for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
292 BlockArgument arg = std::get<0>(it);
293 Type argTy = std::get<1>(it);
294
295 // Unranked memrefs are not supported in the bare pointer calling
296 // convention. We should have bailed out before in the presence of
297 // unranked memrefs.
298 assert(!isa<UnrankedMemRefType>(argTy) &&
299 "Unranked memref is not supported");
300 auto memrefTy = dyn_cast<MemRefType>(argTy);
301 if (!memrefTy)
302 continue;
303
304 // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
305 // or unranked memref descriptor and replace placeholder with the last
306 // instruction of the memref descriptor.
307 // TODO: The placeholder is needed to avoid replacing barePtr uses in the
308 // MemRef descriptor instructions. We may want to have a utility in the
309 // rewriter to properly handle this use case.
310 Location loc = funcOp.getLoc();
311 auto placeholder = rewriter.create<LLVM::UndefOp>(
312 loc, typeConverter.convertType(memrefTy));
313 rewriter.replaceUsesOfBlockArgument(arg, placeholder);
314
315 Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
316 memrefTy, arg);
317 rewriter.replaceOp(placeholder, {desc});
318 }
319}
320
321FailureOr<LLVM::LLVMFuncOp>
322mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
323 ConversionPatternRewriter &rewriter,
324 const LLVMTypeConverter &converter) {
325 // Check the funcOp has `FunctionType`.
326 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
327 if (!funcTy)
328 return rewriter.notifyMatchFailure(
329 funcOp, "Only support FunctionOpInterface with FunctionType");
330
331 // Convert the original function arguments. They are converted using the
332 // LLVMTypeConverter provided to this legalization pattern.
333 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
334 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
335 auto llvmType = converter.convertFunctionSignature(
336 funcTy: funcTy, isVariadic: varargsAttr && varargsAttr.getValue(),
337 useBarePtrCallConv: shouldUseBarePtrCallConv(funcOp, &converter), result);
338 if (!llvmType)
339 return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
340
341 // Create an LLVM function, use external linkage by default until MLIR
342 // functions have linkage.
343 LLVM::Linkage linkage = LLVM::Linkage::External;
344 if (funcOp->hasAttr(linkageAttrName)) {
345 auto attr =
346 dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
347 if (!attr) {
348 funcOp->emitError() << "Contains " << linkageAttrName
349 << " attribute not of type LLVM::LinkageAttr";
350 return rewriter.notifyMatchFailure(
351 funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
352 }
353 linkage = attr.getLinkage();
354 }
355
356 SmallVector<NamedAttribute, 4> attributes;
357 filterFuncAttributes(funcOp, attributes);
358 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
359 funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
360 /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
361 attributes);
362 cast<FunctionOpInterface>(newFuncOp.getOperation())
363 .setVisibility(funcOp.getVisibility());
364
365 // Create a memory effect attribute corresponding to readnone.
366 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
367 if (funcOp->hasAttr(readnoneAttrName)) {
368 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
369 if (!attr) {
370 funcOp->emitError() << "Contains " << readnoneAttrName
371 << " attribute not of type UnitAttr";
372 return rewriter.notifyMatchFailure(
373 funcOp, "Contains readnone attribute not of type UnitAttr");
374 }
375 auto memoryAttr = LLVM::MemoryEffectsAttr::get(
376 rewriter.getContext(),
377 {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
378 LLVM::ModRefInfo::NoModRef});
379 newFuncOp.setMemoryAttr(memoryAttr);
380 }
381
382 // Propagate argument/result attributes to all converted arguments/result
383 // obtained after converting a given original argument/result.
384 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
385 assert(!resAttrDicts.empty() && "expected array to be non-empty");
386 if (funcOp.getNumResults() == 1)
387 newFuncOp.setAllResultAttrs(resAttrDicts);
388 }
389 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
390 SmallVector<Attribute> newArgAttrs(
391 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
392 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
393 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
394 // LLVMFuncOp conversion these types may have changed. Account for that
395 // change by converting attributes' types as well.
396 SmallVector<NamedAttribute, 4> convertedAttrs;
397 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
398 convertedAttrs.reserve(N: attrsDict.size());
399 for (const NamedAttribute &attr : attrsDict) {
400 const auto convert = [&](const NamedAttribute &attr) {
401 return TypeAttr::get(converter.convertType(
402 cast<TypeAttr>(attr.getValue()).getValue()));
403 };
404 if (attr.getName().getValue() ==
405 LLVM::LLVMDialect::getByValAttrName()) {
406 convertedAttrs.push_back(rewriter.getNamedAttr(
407 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
408 } else if (attr.getName().getValue() ==
409 LLVM::LLVMDialect::getByRefAttrName()) {
410 convertedAttrs.push_back(rewriter.getNamedAttr(
411 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
412 } else if (attr.getName().getValue() ==
413 LLVM::LLVMDialect::getStructRetAttrName()) {
414 convertedAttrs.push_back(rewriter.getNamedAttr(
415 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
416 } else if (attr.getName().getValue() ==
417 LLVM::LLVMDialect::getInAllocaAttrName()) {
418 convertedAttrs.push_back(rewriter.getNamedAttr(
419 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
420 } else {
421 convertedAttrs.push_back(attr);
422 }
423 }
424 auto mapping = result.getInputMapping(input: i);
425 assert(mapping && "unexpected deletion of function argument");
426 // Only attach the new argument attributes if there is a one-to-one
427 // mapping from old to new types. Otherwise, attributes might be
428 // attached to types that they do not support.
429 if (mapping->size == 1) {
430 newArgAttrs[mapping->inputNo] =
431 DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
432 continue;
433 }
434 // TODO: Implement custom handling for types that expand to multiple
435 // function arguments.
436 for (size_t j = 0; j < mapping->size; ++j)
437 newArgAttrs[mapping->inputNo + j] =
438 DictionaryAttr::get(rewriter.getContext(), {});
439 }
440 if (!newArgAttrs.empty())
441 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
442 }
443
444 rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
445 newFuncOp.end());
446 if (failed(rewriter.convertRegionTypes(region: &newFuncOp.getBody(), converter,
447 entryConversion: &result))) {
448 return rewriter.notifyMatchFailure(funcOp,
449 "region types conversion failed");
450 }
451
452 return newFuncOp;
453}
454
455namespace {
456
457struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
458protected:
459 using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
460
461 // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
462 // to this legalization pattern.
463 FailureOr<LLVM::LLVMFuncOp>
464 convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
465 ConversionPatternRewriter &rewriter) const {
466 return mlir::convertFuncOpToLLVMFuncOp(
467 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
468 *getTypeConverter());
469 }
470};
471
472/// FuncOp legalization pattern that converts MemRef arguments to pointers to
473/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
474/// information.
475struct FuncOpConversion : public FuncOpConversionBase {
476 FuncOpConversion(const LLVMTypeConverter &converter)
477 : FuncOpConversionBase(converter) {}
478
479 LogicalResult
480 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
481 ConversionPatternRewriter &rewriter) const override {
482 FailureOr<LLVM::LLVMFuncOp> newFuncOp =
483 convertFuncOpToLLVMFuncOp(funcOp, rewriter);
484 if (failed(result: newFuncOp))
485 return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
486
487 if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
488 if (funcOp->getAttrOfType<UnitAttr>(
489 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
490 if (newFuncOp->isVarArg())
491 return funcOp->emitError("C interface for variadic functions is not "
492 "supported yet.");
493
494 if (newFuncOp->isExternal())
495 wrapExternalFunction(rewriter, funcOp->getLoc(), *getTypeConverter(),
496 funcOp, *newFuncOp);
497 else
498 wrapForExternalCallers(rewriter, funcOp->getLoc(),
499 *getTypeConverter(), funcOp, *newFuncOp);
500 }
501 } else {
502 modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp->getLoc(),
503 *getTypeConverter(), *newFuncOp,
504 funcOp.getFunctionType().getInputs());
505 }
506
507 rewriter.eraseOp(op: funcOp);
508 return success();
509 }
510};
511
512struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
513 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
514
515 LogicalResult
516 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
517 ConversionPatternRewriter &rewriter) const override {
518 auto type = typeConverter->convertType(op.getResult().getType());
519 if (!type || !LLVM::isCompatibleType(type: type))
520 return rewriter.notifyMatchFailure(op, "failed to convert result type");
521
522 auto newOp =
523 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
524 for (const NamedAttribute &attr : op->getAttrs()) {
525 if (attr.getName().strref() == "value")
526 continue;
527 newOp->setAttr(attr.getName(), attr.getValue());
528 }
529 rewriter.replaceOp(op, newOp->getResults());
530 return success();
531 }
532};
533
534// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
535// passes the pointer to the MemRef across function boundaries.
536template <typename CallOpType>
537struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
538 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
539 using Super = CallOpInterfaceLowering<CallOpType>;
540 using Base = ConvertOpToLLVMPattern<CallOpType>;
541
542 LogicalResult matchAndRewriteImpl(CallOpType callOp,
543 typename CallOpType::Adaptor adaptor,
544 ConversionPatternRewriter &rewriter,
545 bool useBarePtrCallConv = false) const {
546 // Pack the result types into a struct.
547 Type packedResult = nullptr;
548 unsigned numResults = callOp.getNumResults();
549 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
550
551 if (numResults != 0) {
552 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
553 resultTypes, useBarePtrCallConv)))
554 return failure();
555 }
556
557 if (useBarePtrCallConv) {
558 for (auto it : callOp->getOperands()) {
559 Type operandType = it.getType();
560 if (isa<UnrankedMemRefType>(Val: operandType)) {
561 // Unranked memref is not supported in the bare pointer calling
562 // convention.
563 return failure();
564 }
565 }
566 }
567 auto promoted = this->getTypeConverter()->promoteOperands(
568 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
569 adaptor.getOperands(), rewriter, useBarePtrCallConv);
570 auto newOp = rewriter.create<LLVM::CallOp>(
571 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
572 promoted, callOp->getAttrs());
573
574 SmallVector<Value, 4> results;
575 if (numResults < 2) {
576 // If < 2 results, packing did not do anything and we can just return.
577 results.append(newOp.result_begin(), newOp.result_end());
578 } else {
579 // Otherwise, it had been converted to an operation producing a structure.
580 // Extract individual results from the structure and return them as list.
581 results.reserve(N: numResults);
582 for (unsigned i = 0; i < numResults; ++i) {
583 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
584 callOp.getLoc(), newOp->getResult(0), i));
585 }
586 }
587
588 if (useBarePtrCallConv) {
589 // For the bare-ptr calling convention, promote memref results to
590 // descriptors.
591 assert(results.size() == resultTypes.size() &&
592 "The number of arguments and types doesn't match");
593 this->getTypeConverter()->promoteBarePtrsToDescriptors(
594 rewriter, callOp.getLoc(), resultTypes, results);
595 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
596 resultTypes, results,
597 /*toDynamic=*/false))) {
598 return failure();
599 }
600
601 rewriter.replaceOp(callOp, results);
602 return success();
603 }
604};
605
606class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
607public:
608 CallOpLowering(const LLVMTypeConverter &typeConverter,
609 // Can be nullptr.
610 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
611 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
612 symbolTable(symbolTable) {}
613
614 LogicalResult
615 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter) const override {
617 bool useBarePtrCallConv = false;
618 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
619 useBarePtrCallConv = true;
620 } else if (symbolTable != nullptr) {
621 // Fast lookup.
622 Operation *callee =
623 symbolTable->lookup(callOp.getCalleeAttr().getValue());
624 useBarePtrCallConv =
625 callee != nullptr && callee->hasAttr(name: barePtrAttrName);
626 } else {
627 // Warning: This is a linear lookup.
628 Operation *callee =
629 SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
630 useBarePtrCallConv =
631 callee != nullptr && callee->hasAttr(name: barePtrAttrName);
632 }
633 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
634 }
635
636private:
637 const SymbolTable *symbolTable = nullptr;
638};
639
640struct CallIndirectOpLowering
641 : public CallOpInterfaceLowering<func::CallIndirectOp> {
642 using Super::Super;
643
644 LogicalResult
645 matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
646 ConversionPatternRewriter &rewriter) const override {
647 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
648 }
649};
650
651struct UnrealizedConversionCastOpLowering
652 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
653 using ConvertOpToLLVMPattern<
654 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
655
656 LogicalResult
657 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
658 ConversionPatternRewriter &rewriter) const override {
659 SmallVector<Type> convertedTypes;
660 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
661 convertedTypes)) &&
662 convertedTypes == adaptor.getInputs().getTypes()) {
663 rewriter.replaceOp(op, adaptor.getInputs());
664 return success();
665 }
666
667 convertedTypes.clear();
668 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
669 convertedTypes)) &&
670 convertedTypes == op.getOutputs().getType()) {
671 rewriter.replaceOp(op, adaptor.getInputs());
672 return success();
673 }
674 return failure();
675 }
676};
677
678// Special lowering pattern for `ReturnOps`. Unlike all other operations,
679// `ReturnOp` interacts with the function signature and must have as many
680// operands as the function has return values. Because in LLVM IR, functions
681// can only return 0 or 1 value, we pack multiple values into a structure type.
682// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
683// necessary before returning it
684struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
685 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
686
687 LogicalResult
688 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
689 ConversionPatternRewriter &rewriter) const override {
690 Location loc = op.getLoc();
691 unsigned numArguments = op.getNumOperands();
692 SmallVector<Value, 4> updatedOperands;
693
694 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
695 bool useBarePtrCallConv =
696 shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
697 if (useBarePtrCallConv) {
698 // For the bare-ptr calling convention, extract the aligned pointer to
699 // be returned from the memref descriptor.
700 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
701 Type oldTy = std::get<0>(it).getType();
702 Value newOperand = std::get<1>(it);
703 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
704 cast<BaseMemRefType>(oldTy))) {
705 MemRefDescriptor memrefDesc(newOperand);
706 newOperand = memrefDesc.allocatedPtr(rewriter, loc);
707 } else if (isa<UnrankedMemRefType>(oldTy)) {
708 // Unranked memref is not supported in the bare pointer calling
709 // convention.
710 return failure();
711 }
712 updatedOperands.push_back(newOperand);
713 }
714 } else {
715 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
716 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
717 updatedOperands,
718 /*toDynamic=*/true);
719 }
720
721 // If ReturnOp has 0 or 1 operand, create it and return immediately.
722 if (numArguments <= 1) {
723 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
724 op, TypeRange(), updatedOperands, op->getAttrs());
725 return success();
726 }
727
728 // Otherwise, we need to pack the arguments into an LLVM struct type before
729 // returning.
730 auto packedType = getTypeConverter()->packFunctionResults(
731 op.getOperandTypes(), useBarePtrCallConv);
732 if (!packedType) {
733 return rewriter.notifyMatchFailure(op, "could not convert result types");
734 }
735
736 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
737 for (auto [idx, operand] : llvm::enumerate(First&: updatedOperands)) {
738 packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
739 }
740 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
741 op->getAttrs());
742 return success();
743 }
744};
745} // namespace
746
747void mlir::populateFuncToLLVMFuncOpConversionPattern(
748 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
749 patterns.add<FuncOpConversion>(arg&: converter);
750}
751
752void mlir::populateFuncToLLVMConversionPatterns(
753 LLVMTypeConverter &converter, RewritePatternSet &patterns,
754 const SymbolTable *symbolTable) {
755 populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
756 patterns.add<CallIndirectOpLowering>(arg&: converter);
757 patterns.add<CallOpLowering>(arg&: converter, args&: symbolTable);
758 patterns.add<ConstantOpLowering>(arg&: converter);
759 patterns.add<ReturnOpLowering>(arg&: converter);
760}
761
762namespace {
763/// A pass converting Func operations into the LLVM IR dialect.
764struct ConvertFuncToLLVMPass
765 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
766 using Base::Base;
767
768 /// Run the dialect converter on the module.
769 void runOnOperation() override {
770 ModuleOp m = getOperation();
771 StringRef dataLayout;
772 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
773 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
774 if (dataLayoutAttr)
775 dataLayout = dataLayoutAttr.getValue();
776
777 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
778 dataLayout, [this](const Twine &message) {
779 getOperation().emitError() << message.str();
780 }))) {
781 signalPassFailure();
782 return;
783 }
784
785 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
786
787 LowerToLLVMOptions options(&getContext(),
788 dataLayoutAnalysis.getAtOrAbove(m));
789 options.useBarePtrCallConv = useBarePtrCallConv;
790 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
791 options.overrideIndexBitwidth(indexBitwidth);
792 options.dataLayout = llvm::DataLayout(dataLayout);
793
794 LLVMTypeConverter typeConverter(&getContext(), options,
795 &dataLayoutAnalysis);
796
797 std::optional<SymbolTable> optSymbolTable = std::nullopt;
798 const SymbolTable *symbolTable = nullptr;
799 if (!options.useBarePtrCallConv) {
800 optSymbolTable.emplace(m);
801 symbolTable = &optSymbolTable.value();
802 }
803
804 RewritePatternSet patterns(&getContext());
805 populateFuncToLLVMConversionPatterns(converter&: typeConverter, patterns, symbolTable);
806
807 // TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
808 // favor of their dedicated conversion passes.
809 arith::populateArithToLLVMConversionPatterns(converter&: typeConverter, patterns);
810 cf::populateControlFlowToLLVMConversionPatterns(converter&: typeConverter, patterns);
811
812 LLVMConversionTarget target(getContext());
813 if (failed(applyPartialConversion(m, target, std::move(patterns))))
814 signalPassFailure();
815 }
816};
817
818struct SetLLVMModuleDataLayoutPass
819 : public impl::SetLLVMModuleDataLayoutPassBase<
820 SetLLVMModuleDataLayoutPass> {
821 using Base::Base;
822
823 /// Run the dialect converter on the module.
824 void runOnOperation() override {
825 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
826 this->dataLayout, [this](const Twine &message) {
827 getOperation().emitError() << message.str();
828 }))) {
829 signalPassFailure();
830 return;
831 }
832 ModuleOp m = getOperation();
833 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
834 StringAttr::get(m.getContext(), this->dataLayout));
835 }
836};
837} // namespace
838
839//===----------------------------------------------------------------------===//
840// ConvertToLLVMPatternInterface implementation
841//===----------------------------------------------------------------------===//
842
843namespace {
844/// Implement the interface to convert Func to LLVM.
845struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
846 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
847 /// Hook for derived dialect interface to provide conversion patterns
848 /// and mark dialect legal for the conversion target.
849 void populateConvertToLLVMConversionPatterns(
850 ConversionTarget &target, LLVMTypeConverter &typeConverter,
851 RewritePatternSet &patterns) const final {
852 populateFuncToLLVMConversionPatterns(converter&: typeConverter, patterns);
853 }
854};
855} // namespace
856
857void mlir::registerConvertFuncToLLVMInterface(DialectRegistry &registry) {
858 registry.addExtension(extensionFn: +[](MLIRContext *ctx, func::FuncDialect *dialect) {
859 dialect->addInterfaces<FuncToLLVMDialectInterface>();
860 });
861}
862

source code of mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp