1//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements helper functions that look up or declare common, simple
// C functions so they can be called from LLVMIR (e.g., to support printing and
// debugging, among other uses).
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/Support/LLVM.h"
19
20using namespace mlir;
21using namespace mlir::LLVM;
22
23/// Helper functions to lookup or create the declaration for commonly used
24/// external C function calls. The list of functions provided here must be
25/// implemented separately (e.g. as part of a support runtime library or as
26/// part of the libc).
27static constexpr llvm::StringRef kPrintI64 = "printI64";
28static constexpr llvm::StringRef kPrintU64 = "printU64";
29static constexpr llvm::StringRef kPrintF16 = "printF16";
30static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31static constexpr llvm::StringRef kPrintF32 = "printF32";
32static constexpr llvm::StringRef kPrintF64 = "printF64";
33static constexpr llvm::StringRef kPrintString = "printString";
34static constexpr llvm::StringRef kPrintOpen = "printOpen";
35static constexpr llvm::StringRef kPrintClose = "printClose";
36static constexpr llvm::StringRef kPrintComma = "printComma";
37static constexpr llvm::StringRef kPrintNewline = "printNewline";
38static constexpr llvm::StringRef kMalloc = "malloc";
39static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
40static constexpr llvm::StringRef kFree = "free";
41static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
42static constexpr llvm::StringRef kGenericAlignedAlloc =
43 "_mlir_memref_to_llvm_aligned_alloc";
44static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
45static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
46
47/// Generic print function lookupOrCreate helper.
48FailureOr<LLVM::LLVMFuncOp>
49mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
50 ArrayRef<Type> paramTypes, Type resultType,
51 bool isVarArg, bool isReserved) {
52 assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
53 "expected SymbolTable operation");
54 auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
55 SymbolTable::lookupSymbolIn(op: moduleOp, symbol: name));
56 auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
57 // Assert the signature of the found function is same as expected
58 if (func) {
59 if (funcT != func.getFunctionType()) {
60 if (isReserved) {
61 func.emitError("redefinition of reserved function '")
62 << name << "' of different type " << func.getFunctionType()
63 << " is prohibited";
64 } else {
65 func.emitError("redefinition of function '")
66 << name << "' of different type " << funcT << " is prohibited";
67 }
68 return failure();
69 }
70 return func;
71 }
72
73 OpBuilder::InsertionGuard g(b);
74 assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
75 b.setInsertionPointToStart(&moduleOp->getRegion(index: 0).front());
76 return b.create<LLVM::LLVMFuncOp>(
77 moduleOp->getLoc(), name,
78 LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
79}
80
81static FailureOr<LLVM::LLVMFuncOp>
82lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
83 ArrayRef<Type> paramTypes, Type resultType) {
84 return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
85 /*isVarArg=*/false, /*isReserved=*/true);
86}
87
88FailureOr<LLVM::LLVMFuncOp>
89mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) {
90 return lookupOrCreateReservedFn(
91 b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
92 LLVM::LLVMVoidType::get(moduleOp->getContext()));
93}
94
95FailureOr<LLVM::LLVMFuncOp>
96mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) {
97 return lookupOrCreateReservedFn(
98 b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
99 LLVM::LLVMVoidType::get(moduleOp->getContext()));
100}
101
102FailureOr<LLVM::LLVMFuncOp>
103mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) {
104 return lookupOrCreateReservedFn(
105 b, moduleOp, kPrintF16,
106 IntegerType::get(moduleOp->getContext(), 16), // bits!
107 LLVM::LLVMVoidType::get(moduleOp->getContext()));
108}
109
110FailureOr<LLVM::LLVMFuncOp>
111mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) {
112 return lookupOrCreateReservedFn(
113 b, moduleOp, kPrintBF16,
114 IntegerType::get(moduleOp->getContext(), 16), // bits!
115 LLVM::LLVMVoidType::get(moduleOp->getContext()));
116}
117
118FailureOr<LLVM::LLVMFuncOp>
119mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) {
120 return lookupOrCreateReservedFn(
121 b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
122 LLVM::LLVMVoidType::get(moduleOp->getContext()));
123}
124
125FailureOr<LLVM::LLVMFuncOp>
126mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) {
127 return lookupOrCreateReservedFn(
128 b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
129 LLVM::LLVMVoidType::get(moduleOp->getContext()));
130}
131
132static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
133 return LLVM::LLVMPointerType::get(context);
134}
135
136static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
137 // A char pointer and void ptr are the same in LLVM IR.
138 return getCharPtr(context);
139}
140
141FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
142 OpBuilder &b, Operation *moduleOp,
143 std::optional<StringRef> runtimeFunctionName) {
144 return lookupOrCreateReservedFn(
145 b, moduleOp, runtimeFunctionName.value_or(kPrintString),
146 getCharPtr(moduleOp->getContext()),
147 LLVM::LLVMVoidType::get(moduleOp->getContext()));
148}
149
150FailureOr<LLVM::LLVMFuncOp>
151mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) {
152 return lookupOrCreateReservedFn(
153 b, moduleOp, name: kPrintOpen, paramTypes: {},
154 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()));
155}
156
157FailureOr<LLVM::LLVMFuncOp>
158mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) {
159 return lookupOrCreateReservedFn(
160 b, moduleOp, name: kPrintClose, paramTypes: {},
161 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()));
162}
163
164FailureOr<LLVM::LLVMFuncOp>
165mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) {
166 return lookupOrCreateReservedFn(
167 b, moduleOp, name: kPrintComma, paramTypes: {},
168 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()));
169}
170
171FailureOr<LLVM::LLVMFuncOp>
172mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) {
173 return lookupOrCreateReservedFn(
174 b, moduleOp, name: kPrintNewline, paramTypes: {},
175 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()));
176}
177
178FailureOr<LLVM::LLVMFuncOp>
179mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp,
180 Type indexType) {
181 return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
182 getVoidPtr(moduleOp->getContext()));
183}
184
185FailureOr<LLVM::LLVMFuncOp>
186mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
187 Type indexType) {
188 return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc,
189 {indexType, indexType},
190 getVoidPtr(moduleOp->getContext()));
191}
192
193FailureOr<LLVM::LLVMFuncOp>
194mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) {
195 return lookupOrCreateReservedFn(
196 b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
197 LLVM::LLVMVoidType::get(moduleOp->getContext()));
198}
199
200FailureOr<LLVM::LLVMFuncOp>
201mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp,
202 Type indexType) {
203 return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
204 getVoidPtr(moduleOp->getContext()));
205}
206
207FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(
208 OpBuilder &b, Operation *moduleOp, Type indexType) {
209 return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc,
210 {indexType, indexType},
211 getVoidPtr(moduleOp->getContext()));
212}
213
214FailureOr<LLVM::LLVMFuncOp>
215mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) {
216 return lookupOrCreateReservedFn(
217 b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
218 LLVM::LLVMVoidType::get(moduleOp->getContext()));
219}
220
221FailureOr<LLVM::LLVMFuncOp>
222mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp,
223 Type indexType,
224 Type unrankedDescriptorType) {
225 return lookupOrCreateReservedFn(
226 b, moduleOp, name: kMemRefCopy,
227 paramTypes: ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
228 resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()));
229}
230

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp