1 | //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements helper functions to look up or declare commonly used
// simple C functions in the LLVM dialect (e.g., among others, to support
// printing and debugging).
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/OpDefinition.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::LLVM; |
22 | |
23 | /// Helper functions to lookup or create the declaration for commonly used |
24 | /// external C function calls. The list of functions provided here must be |
25 | /// implemented separately (e.g. as part of a support runtime library or as |
26 | /// part of the libc). |
27 | static constexpr llvm::StringRef kPrintI64 = "printI64" ; |
28 | static constexpr llvm::StringRef kPrintU64 = "printU64" ; |
29 | static constexpr llvm::StringRef kPrintF16 = "printF16" ; |
30 | static constexpr llvm::StringRef kPrintBF16 = "printBF16" ; |
31 | static constexpr llvm::StringRef kPrintF32 = "printF32" ; |
32 | static constexpr llvm::StringRef kPrintF64 = "printF64" ; |
33 | static constexpr llvm::StringRef kPrintString = "printString" ; |
34 | static constexpr llvm::StringRef kPrintOpen = "printOpen" ; |
35 | static constexpr llvm::StringRef kPrintClose = "printClose" ; |
36 | static constexpr llvm::StringRef kPrintComma = "printComma" ; |
37 | static constexpr llvm::StringRef kPrintNewline = "printNewline" ; |
38 | static constexpr llvm::StringRef kMalloc = "malloc" ; |
39 | static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc" ; |
40 | static constexpr llvm::StringRef kFree = "free" ; |
41 | static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc" ; |
42 | static constexpr llvm::StringRef kGenericAlignedAlloc = |
43 | "_mlir_memref_to_llvm_aligned_alloc" ; |
44 | static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free" ; |
45 | static constexpr llvm::StringRef kMemRefCopy = "memrefCopy" ; |
46 | |
47 | /// Generic print function lookupOrCreate helper. |
48 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name, |
49 | ArrayRef<Type> paramTypes, |
50 | Type resultType, bool isVarArg) { |
51 | auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name); |
52 | if (func) |
53 | return func; |
54 | OpBuilder b(moduleOp.getBodyRegion()); |
55 | return b.create<LLVM::LLVMFuncOp>( |
56 | moduleOp->getLoc(), name, |
57 | LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); |
58 | } |
59 | |
60 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) { |
61 | return lookupOrCreateFn(moduleOp, kPrintI64, |
62 | IntegerType::get(moduleOp->getContext(), 64), |
63 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
64 | } |
65 | |
66 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { |
67 | return lookupOrCreateFn(moduleOp, kPrintU64, |
68 | IntegerType::get(moduleOp->getContext(), 64), |
69 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
70 | } |
71 | |
72 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) { |
73 | return lookupOrCreateFn(moduleOp, kPrintF16, |
74 | IntegerType::get(moduleOp->getContext(), 16), // bits! |
75 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
76 | } |
77 | |
78 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) { |
79 | return lookupOrCreateFn(moduleOp, kPrintBF16, |
80 | IntegerType::get(moduleOp->getContext(), 16), // bits! |
81 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
82 | } |
83 | |
84 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { |
85 | return lookupOrCreateFn(moduleOp, kPrintF32, |
86 | Float32Type::get(moduleOp->getContext()), |
87 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
88 | } |
89 | |
90 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) { |
91 | return lookupOrCreateFn(moduleOp, kPrintF64, |
92 | Float64Type::get(moduleOp->getContext()), |
93 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
94 | } |
95 | |
96 | static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { |
97 | return LLVM::LLVMPointerType::get(context); |
98 | } |
99 | |
100 | static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { |
101 | // A char pointer and void ptr are the same in LLVM IR. |
102 | return getCharPtr(context); |
103 | } |
104 | |
105 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn( |
106 | ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) { |
107 | return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString), |
108 | getCharPtr(moduleOp->getContext()), |
109 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
110 | } |
111 | |
112 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { |
113 | return lookupOrCreateFn(moduleOp, kPrintOpen, {}, |
114 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
115 | } |
116 | |
117 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) { |
118 | return lookupOrCreateFn(moduleOp, kPrintClose, {}, |
119 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
120 | } |
121 | |
122 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) { |
123 | return lookupOrCreateFn(moduleOp, kPrintComma, {}, |
124 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
125 | } |
126 | |
127 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) { |
128 | return lookupOrCreateFn(moduleOp, kPrintNewline, {}, |
129 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
130 | } |
131 | |
132 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp, |
133 | Type indexType) { |
134 | return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType, |
135 | getVoidPtr(moduleOp->getContext())); |
136 | } |
137 | |
138 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, |
139 | Type indexType) { |
140 | return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType}, |
141 | getVoidPtr(moduleOp->getContext())); |
142 | } |
143 | |
144 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) { |
145 | return LLVM::lookupOrCreateFn( |
146 | moduleOp, kFree, getVoidPtr(moduleOp->getContext()), |
147 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
148 | } |
149 | |
150 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp, |
151 | Type indexType) { |
152 | return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType, |
153 | getVoidPtr(moduleOp->getContext())); |
154 | } |
155 | |
156 | LLVM::LLVMFuncOp |
157 | mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp, |
158 | Type indexType) { |
159 | return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc, |
160 | {indexType, indexType}, |
161 | getVoidPtr(moduleOp->getContext())); |
162 | } |
163 | |
164 | LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) { |
165 | return LLVM::lookupOrCreateFn( |
166 | moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), |
167 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
168 | } |
169 | |
170 | LLVM::LLVMFuncOp |
171 | mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, |
172 | Type unrankedDescriptorType) { |
173 | return LLVM::lookupOrCreateFn( |
174 | moduleOp, kMemRefCopy, |
175 | ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType}, |
176 | LLVM::LLVMVoidType::get(moduleOp->getContext())); |
177 | } |
178 | |