//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g., to support printing and debugging, among other uses).
//
//===----------------------------------------------------------------------===//
13 | |
14 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/OpDefinition.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::LLVM; |
22 | |
23 | /// Helper functions to lookup or create the declaration for commonly used |
24 | /// external C function calls. The list of functions provided here must be |
25 | /// implemented separately (e.g. as part of a support runtime library or as |
26 | /// part of the libc). |
27 | static constexpr llvm::StringRef kPrintI64 = "printI64"; |
28 | static constexpr llvm::StringRef kPrintU64 = "printU64"; |
29 | static constexpr llvm::StringRef kPrintF16 = "printF16"; |
30 | static constexpr llvm::StringRef kPrintBF16 = "printBF16"; |
31 | static constexpr llvm::StringRef kPrintF32 = "printF32"; |
32 | static constexpr llvm::StringRef kPrintF64 = "printF64"; |
33 | static constexpr llvm::StringRef kPrintString = "printString"; |
34 | static constexpr llvm::StringRef kPrintOpen = "printOpen"; |
35 | static constexpr llvm::StringRef kPrintClose = "printClose"; |
36 | static constexpr llvm::StringRef kPrintComma = "printComma"; |
37 | static constexpr llvm::StringRef kPrintNewline = "printNewline"; |
38 | static constexpr llvm::StringRef kMalloc = "malloc"; |
39 | static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc"; |
40 | static constexpr llvm::StringRef kFree = "free"; |
41 | static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc"; |
42 | static constexpr llvm::StringRef kGenericAlignedAlloc = |
43 | "_mlir_memref_to_llvm_aligned_alloc"; |
44 | static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free"; |
45 | static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; |
46 | |
47 | /// Generic print function lookupOrCreate helper. |
48 | FailureOr<LLVM::LLVMFuncOp> |
49 | mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, |
50 | ArrayRef<Type> paramTypes, Type resultType, |
51 | bool isVarArg, bool isReserved) { |
52 | assert(moduleOp->hasTrait<OpTrait::SymbolTable>() && |
53 | "expected SymbolTable operation"); |
54 | auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>( |
55 | SymbolTable::lookupSymbolIn(op: moduleOp, symbol: name)); |
56 | auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg); |
57 | // Assert the signature of the found function is same as expected |
58 | if (func) { |
59 | if (funcT != func.getFunctionType()) { |
60 | if (isReserved) { |
61 | func.emitError("redefinition of reserved function '") |
62 | << name << "' of different type "<< func.getFunctionType() |
63 | << " is prohibited"; |
64 | } else { |
65 | func.emitError("redefinition of function '") |
66 | << name << "' of different type "<< funcT << " is prohibited"; |
67 | } |
68 | return failure(); |
69 | } |
70 | return func; |
71 | } |
72 | |
73 | OpBuilder::InsertionGuard g(b); |
74 | assert(!moduleOp->getRegion(0).empty() && "expected non-empty region"); |
75 | b.setInsertionPointToStart(&moduleOp->getRegion(index: 0).front()); |
76 | return b.create<LLVM::LLVMFuncOp>( |
77 | moduleOp->getLoc(), name, |
78 | LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); |
79 | } |
80 | |
81 | static FailureOr<LLVM::LLVMFuncOp> |
82 | lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, |
83 | ArrayRef<Type> paramTypes, Type resultType) { |
84 | return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType, |
85 | /*isVarArg=*/false, /*isReserved=*/true); |
86 | } |
87 | |
88 | FailureOr<LLVM::LLVMFuncOp> |
89 | mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) { |
90 | return lookupOrCreateReservedFn( |
91 | b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), |
92 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
93 | } |
94 | |
95 | FailureOr<LLVM::LLVMFuncOp> |
96 | mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) { |
97 | return lookupOrCreateReservedFn( |
98 | b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), |
99 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
100 | } |
101 | |
102 | FailureOr<LLVM::LLVMFuncOp> |
103 | mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) { |
104 | return lookupOrCreateReservedFn( |
105 | b, moduleOp, kPrintF16, |
106 | IntegerType::get(moduleOp->getContext(), 16), // bits! |
107 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
108 | } |
109 | |
110 | FailureOr<LLVM::LLVMFuncOp> |
111 | mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) { |
112 | return lookupOrCreateReservedFn( |
113 | b, moduleOp, kPrintBF16, |
114 | IntegerType::get(moduleOp->getContext(), 16), // bits! |
115 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
116 | } |
117 | |
118 | FailureOr<LLVM::LLVMFuncOp> |
119 | mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) { |
120 | return lookupOrCreateReservedFn( |
121 | b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), |
122 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
123 | } |
124 | |
125 | FailureOr<LLVM::LLVMFuncOp> |
126 | mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) { |
127 | return lookupOrCreateReservedFn( |
128 | b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), |
129 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
130 | } |
131 | |
132 | static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { |
133 | return LLVM::LLVMPointerType::get(context); |
134 | } |
135 | |
136 | static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { |
137 | // A char pointer and void ptr are the same in LLVM IR. |
138 | return getCharPtr(context); |
139 | } |
140 | |
141 | FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn( |
142 | OpBuilder &b, Operation *moduleOp, |
143 | std::optional<StringRef> runtimeFunctionName) { |
144 | return lookupOrCreateReservedFn( |
145 | b, moduleOp, runtimeFunctionName.value_or(kPrintString), |
146 | getCharPtr(moduleOp->getContext()), |
147 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
148 | } |
149 | |
150 | FailureOr<LLVM::LLVMFuncOp> |
151 | mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) { |
152 | return lookupOrCreateReservedFn( |
153 | b, moduleOp, name: kPrintOpen, paramTypes: {}, |
154 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext())); |
155 | } |
156 | |
157 | FailureOr<LLVM::LLVMFuncOp> |
158 | mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) { |
159 | return lookupOrCreateReservedFn( |
160 | b, moduleOp, name: kPrintClose, paramTypes: {}, |
161 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext())); |
162 | } |
163 | |
164 | FailureOr<LLVM::LLVMFuncOp> |
165 | mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) { |
166 | return lookupOrCreateReservedFn( |
167 | b, moduleOp, name: kPrintComma, paramTypes: {}, |
168 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext())); |
169 | } |
170 | |
171 | FailureOr<LLVM::LLVMFuncOp> |
172 | mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) { |
173 | return lookupOrCreateReservedFn( |
174 | b, moduleOp, name: kPrintNewline, paramTypes: {}, |
175 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext())); |
176 | } |
177 | |
178 | FailureOr<LLVM::LLVMFuncOp> |
179 | mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, |
180 | Type indexType) { |
181 | return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType, |
182 | getVoidPtr(moduleOp->getContext())); |
183 | } |
184 | |
185 | FailureOr<LLVM::LLVMFuncOp> |
186 | mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, |
187 | Type indexType) { |
188 | return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc, |
189 | {indexType, indexType}, |
190 | getVoidPtr(moduleOp->getContext())); |
191 | } |
192 | |
193 | FailureOr<LLVM::LLVMFuncOp> |
194 | mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) { |
195 | return lookupOrCreateReservedFn( |
196 | b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()), |
197 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
198 | } |
199 | |
200 | FailureOr<LLVM::LLVMFuncOp> |
201 | mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, |
202 | Type indexType) { |
203 | return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType, |
204 | getVoidPtr(moduleOp->getContext())); |
205 | } |
206 | |
207 | FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn( |
208 | OpBuilder &b, Operation *moduleOp, Type indexType) { |
209 | return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc, |
210 | {indexType, indexType}, |
211 | getVoidPtr(moduleOp->getContext())); |
212 | } |
213 | |
214 | FailureOr<LLVM::LLVMFuncOp> |
215 | mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) { |
216 | return lookupOrCreateReservedFn( |
217 | b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), |
218 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
219 | } |
220 | |
221 | FailureOr<LLVM::LLVMFuncOp> |
222 | mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, |
223 | Type indexType, |
224 | Type unrankedDescriptorType) { |
225 | return lookupOrCreateReservedFn( |
226 | b, moduleOp, name: kMemRefCopy, |
227 | paramTypes: ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType}, |
228 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext())); |
229 | } |
230 |
Definitions
- kPrintI64
- kPrintU64
- kPrintF16
- kPrintBF16
- kPrintF32
- kPrintF64
- kPrintString
- kPrintOpen
- kPrintClose
- kPrintComma
- kPrintNewline
- kMalloc
- kAlignedAlloc
- kFree
- kGenericAlloc
- kGenericAlignedAlloc
- kGenericFree
- kMemRefCopy
- lookupOrCreateFn
- lookupOrCreateReservedFn
- lookupOrCreatePrintI64Fn
- lookupOrCreatePrintU64Fn
- lookupOrCreatePrintF16Fn
- lookupOrCreatePrintBF16Fn
- lookupOrCreatePrintF32Fn
- lookupOrCreatePrintF64Fn
- getCharPtr
- getVoidPtr
- lookupOrCreatePrintStringFn
- lookupOrCreatePrintOpenFn
- lookupOrCreatePrintCloseFn
- lookupOrCreatePrintCommaFn
- lookupOrCreatePrintNewlineFn
- lookupOrCreateMallocFn
- lookupOrCreateAlignedAllocFn
- lookupOrCreateFreeFn
- lookupOrCreateGenericAllocFn
- lookupOrCreateGenericAlignedAllocFn
- lookupOrCreateGenericFreeFn
Improve your Profiling and Debugging skills
Find out more