1//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. among others to support printing and debugging).
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/Support/LLVM.h"
19
20using namespace mlir;
21using namespace mlir::LLVM;
22
23/// Helper functions to lookup or create the declaration for commonly used
24/// external C function calls. The list of functions provided here must be
25/// implemented separately (e.g. as part of a support runtime library or as
26/// part of the libc).
27static constexpr llvm::StringRef kPrintI64 = "printI64";
28static constexpr llvm::StringRef kPrintU64 = "printU64";
29static constexpr llvm::StringRef kPrintF16 = "printF16";
30static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31static constexpr llvm::StringRef kPrintF32 = "printF32";
32static constexpr llvm::StringRef kPrintF64 = "printF64";
33static constexpr llvm::StringRef kPrintString = "printString";
34static constexpr llvm::StringRef kPrintOpen = "printOpen";
35static constexpr llvm::StringRef kPrintClose = "printClose";
36static constexpr llvm::StringRef kPrintComma = "printComma";
37static constexpr llvm::StringRef kPrintNewline = "printNewline";
38static constexpr llvm::StringRef kMalloc = "malloc";
39static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
40static constexpr llvm::StringRef kFree = "free";
41static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
42static constexpr llvm::StringRef kGenericAlignedAlloc =
43 "_mlir_memref_to_llvm_aligned_alloc";
44static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
45static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
46
47namespace {
48/// Search for an LLVMFuncOp with a given name within an operation with the
49/// SymbolTable trait. An optional collection of cached symbol tables can be
50/// given to avoid a linear scan of the symbol table operation.
51LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
52 SymbolTableCollection *symbolTables = nullptr) {
53 if (symbolTables) {
54 return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
55 symbolTableOp, name: StringAttr::get(context: symbolTableOp->getContext(), bytes: name));
56 }
57
58 return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
59 Val: SymbolTable::lookupSymbolIn(op: symbolTableOp, symbol: name));
60}
61} // namespace
62
63/// Generic print function lookupOrCreate helper.
64FailureOr<LLVM::LLVMFuncOp>
65mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
66 ArrayRef<Type> paramTypes, Type resultType,
67 bool isVarArg, bool isReserved,
68 SymbolTableCollection *symbolTables) {
69 assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
70 "expected SymbolTable operation");
71 auto func = lookupFuncOp(name, symbolTableOp: moduleOp, symbolTables);
72 auto funcT = LLVMFunctionType::get(result: resultType, arguments: paramTypes, isVarArg);
73 // Assert the signature of the found function is same as expected
74 if (func) {
75 if (funcT != func.getFunctionType()) {
76 if (isReserved) {
77 func.emitError(message: "redefinition of reserved function '")
78 << name << "' of different type " << func.getFunctionType()
79 << " is prohibited";
80 } else {
81 func.emitError(message: "redefinition of function '")
82 << name << "' of different type " << funcT << " is prohibited";
83 }
84 return failure();
85 }
86 return func;
87 }
88
89 OpBuilder::InsertionGuard g(b);
90 assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
91 b.setInsertionPointToStart(&moduleOp->getRegion(index: 0).front());
92 auto funcOp = b.create<LLVM::LLVMFuncOp>(
93 location: moduleOp->getLoc(), args&: name,
94 args: LLVM::LLVMFunctionType::get(result: resultType, arguments: paramTypes, isVarArg));
95
96 if (symbolTables) {
97 SymbolTable &symbolTable = symbolTables->getSymbolTable(op: moduleOp);
98 symbolTable.insert(symbol: funcOp, insertPt: moduleOp->getRegion(index: 0).front().begin());
99 }
100
101 return funcOp;
102}
103
104static FailureOr<LLVM::LLVMFuncOp>
105lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
106 ArrayRef<Type> paramTypes, Type resultType,
107 SymbolTableCollection *symbolTables) {
108 return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
109 /*isVarArg=*/false, /*isReserved=*/true,
110 symbolTables);
111}
112
113FailureOr<LLVM::LLVMFuncOp>
114mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
115 SymbolTableCollection *symbolTables) {
116 return lookupOrCreateReservedFn(
117 b, moduleOp, name: kPrintI64, paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 64),
118 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
119}
120
121FailureOr<LLVM::LLVMFuncOp>
122mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
123 SymbolTableCollection *symbolTables) {
124 return lookupOrCreateReservedFn(
125 b, moduleOp, name: kPrintU64, paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 64),
126 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
127}
128
129FailureOr<LLVM::LLVMFuncOp>
130mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
131 SymbolTableCollection *symbolTables) {
132 return lookupOrCreateReservedFn(
133 b, moduleOp, name: kPrintF16,
134 paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 16), // bits!
135 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
136}
137
138FailureOr<LLVM::LLVMFuncOp>
139mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
140 SymbolTableCollection *symbolTables) {
141 return lookupOrCreateReservedFn(
142 b, moduleOp, name: kPrintBF16,
143 paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 16), // bits!
144 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
145}
146
147FailureOr<LLVM::LLVMFuncOp>
148mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
149 SymbolTableCollection *symbolTables) {
150 return lookupOrCreateReservedFn(
151 b, moduleOp, name: kPrintF32, paramTypes: Float32Type::get(context: moduleOp->getContext()),
152 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
153}
154
155FailureOr<LLVM::LLVMFuncOp>
156mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
157 SymbolTableCollection *symbolTables) {
158 return lookupOrCreateReservedFn(
159 b, moduleOp, name: kPrintF64, paramTypes: Float64Type::get(context: moduleOp->getContext()),
160 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
161}
162
163static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
164 return LLVM::LLVMPointerType::get(context);
165}
166
167static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
168 // A char pointer and void ptr are the same in LLVM IR.
169 return getCharPtr(context);
170}
171
172FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
173 OpBuilder &b, Operation *moduleOp,
174 std::optional<StringRef> runtimeFunctionName,
175 SymbolTableCollection *symbolTables) {
176 return lookupOrCreateReservedFn(
177 b, moduleOp, name: runtimeFunctionName.value_or(u: kPrintString),
178 paramTypes: getCharPtr(context: moduleOp->getContext()),
179 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
180}
181
182FailureOr<LLVM::LLVMFuncOp>
183mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
184 SymbolTableCollection *symbolTables) {
185 return lookupOrCreateReservedFn(
186 b, moduleOp, name: kPrintOpen, paramTypes: {},
187 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
188}
189
190FailureOr<LLVM::LLVMFuncOp>
191mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
192 SymbolTableCollection *symbolTables) {
193 return lookupOrCreateReservedFn(
194 b, moduleOp, name: kPrintClose, paramTypes: {},
195 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
196}
197
198FailureOr<LLVM::LLVMFuncOp>
199mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
200 SymbolTableCollection *symbolTables) {
201 return lookupOrCreateReservedFn(
202 b, moduleOp, name: kPrintComma, paramTypes: {},
203 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
204}
205
206FailureOr<LLVM::LLVMFuncOp>
207mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
208 SymbolTableCollection *symbolTables) {
209 return lookupOrCreateReservedFn(
210 b, moduleOp, name: kPrintNewline, paramTypes: {},
211 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
212}
213
214FailureOr<LLVM::LLVMFuncOp>
215mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp,
216 Type indexType,
217 SymbolTableCollection *symbolTables) {
218 return lookupOrCreateReservedFn(b, moduleOp, name: kMalloc, paramTypes: indexType,
219 resultType: getVoidPtr(context: moduleOp->getContext()),
220 symbolTables);
221}
222
223FailureOr<LLVM::LLVMFuncOp>
224mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
225 Type indexType,
226 SymbolTableCollection *symbolTables) {
227 return lookupOrCreateReservedFn(
228 b, moduleOp, name: kAlignedAlloc, paramTypes: {indexType, indexType},
229 resultType: getVoidPtr(context: moduleOp->getContext()), symbolTables);
230}
231
232FailureOr<LLVM::LLVMFuncOp>
233mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
234 SymbolTableCollection *symbolTables) {
235 return lookupOrCreateReservedFn(
236 b, moduleOp, name: kFree, paramTypes: getVoidPtr(context: moduleOp->getContext()),
237 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
238}
239
240FailureOr<LLVM::LLVMFuncOp>
241mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp,
242 Type indexType,
243 SymbolTableCollection *symbolTables) {
244 return lookupOrCreateReservedFn(b, moduleOp, name: kGenericAlloc, paramTypes: indexType,
245 resultType: getVoidPtr(context: moduleOp->getContext()),
246 symbolTables);
247}
248
249FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(
250 OpBuilder &b, Operation *moduleOp, Type indexType,
251 SymbolTableCollection *symbolTables) {
252 return lookupOrCreateReservedFn(
253 b, moduleOp, name: kGenericAlignedAlloc, paramTypes: {indexType, indexType},
254 resultType: getVoidPtr(context: moduleOp->getContext()), symbolTables);
255}
256
257FailureOr<LLVM::LLVMFuncOp>
258mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
259 SymbolTableCollection *symbolTables) {
260 return lookupOrCreateReservedFn(
261 b, moduleOp, name: kGenericFree, paramTypes: getVoidPtr(context: moduleOp->getContext()),
262 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
263}
264
265FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
266 OpBuilder &b, Operation *moduleOp, Type indexType,
267 Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
268 return lookupOrCreateReservedFn(
269 b, moduleOp, name: kMemRefCopy,
270 paramTypes: ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
271 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables);
272}
273

source code of mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp