| 1 | //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g., among others, to support printing and debugging).
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
| 15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 16 | #include "mlir/IR/Builders.h" |
| 17 | #include "mlir/IR/OpDefinition.h" |
| 18 | #include "mlir/Support/LLVM.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::LLVM; |
| 22 | |
| 23 | /// Helper functions to lookup or create the declaration for commonly used |
| 24 | /// external C function calls. The list of functions provided here must be |
| 25 | /// implemented separately (e.g. as part of a support runtime library or as |
| 26 | /// part of the libc). |
| 27 | static constexpr llvm::StringRef kPrintI64 = "printI64" ; |
| 28 | static constexpr llvm::StringRef kPrintU64 = "printU64" ; |
| 29 | static constexpr llvm::StringRef kPrintF16 = "printF16" ; |
| 30 | static constexpr llvm::StringRef kPrintBF16 = "printBF16" ; |
| 31 | static constexpr llvm::StringRef kPrintF32 = "printF32" ; |
| 32 | static constexpr llvm::StringRef kPrintF64 = "printF64" ; |
| 33 | static constexpr llvm::StringRef kPrintString = "printString" ; |
| 34 | static constexpr llvm::StringRef kPrintOpen = "printOpen" ; |
| 35 | static constexpr llvm::StringRef kPrintClose = "printClose" ; |
| 36 | static constexpr llvm::StringRef kPrintComma = "printComma" ; |
| 37 | static constexpr llvm::StringRef kPrintNewline = "printNewline" ; |
| 38 | static constexpr llvm::StringRef kMalloc = "malloc" ; |
| 39 | static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc" ; |
| 40 | static constexpr llvm::StringRef kFree = "free" ; |
| 41 | static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc" ; |
| 42 | static constexpr llvm::StringRef kGenericAlignedAlloc = |
| 43 | "_mlir_memref_to_llvm_aligned_alloc" ; |
| 44 | static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free" ; |
| 45 | static constexpr llvm::StringRef kMemRefCopy = "memrefCopy" ; |
| 46 | |
| 47 | namespace { |
| 48 | /// Search for an LLVMFuncOp with a given name within an operation with the |
| 49 | /// SymbolTable trait. An optional collection of cached symbol tables can be |
| 50 | /// given to avoid a linear scan of the symbol table operation. |
| 51 | LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp, |
| 52 | SymbolTableCollection *symbolTables = nullptr) { |
| 53 | if (symbolTables) { |
| 54 | return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>( |
| 55 | symbolTableOp, name: StringAttr::get(context: symbolTableOp->getContext(), bytes: name)); |
| 56 | } |
| 57 | |
| 58 | return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>( |
| 59 | Val: SymbolTable::lookupSymbolIn(op: symbolTableOp, symbol: name)); |
| 60 | } |
| 61 | } // namespace |
| 62 | |
| 63 | /// Generic print function lookupOrCreate helper. |
| 64 | FailureOr<LLVM::LLVMFuncOp> |
| 65 | mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, |
| 66 | ArrayRef<Type> paramTypes, Type resultType, |
| 67 | bool isVarArg, bool isReserved, |
| 68 | SymbolTableCollection *symbolTables) { |
| 69 | assert(moduleOp->hasTrait<OpTrait::SymbolTable>() && |
| 70 | "expected SymbolTable operation" ); |
| 71 | auto func = lookupFuncOp(name, symbolTableOp: moduleOp, symbolTables); |
| 72 | auto funcT = LLVMFunctionType::get(result: resultType, arguments: paramTypes, isVarArg); |
| 73 | // Assert the signature of the found function is same as expected |
| 74 | if (func) { |
| 75 | if (funcT != func.getFunctionType()) { |
| 76 | if (isReserved) { |
| 77 | func.emitError(message: "redefinition of reserved function '" ) |
| 78 | << name << "' of different type " << func.getFunctionType() |
| 79 | << " is prohibited" ; |
| 80 | } else { |
| 81 | func.emitError(message: "redefinition of function '" ) |
| 82 | << name << "' of different type " << funcT << " is prohibited" ; |
| 83 | } |
| 84 | return failure(); |
| 85 | } |
| 86 | return func; |
| 87 | } |
| 88 | |
| 89 | OpBuilder::InsertionGuard g(b); |
| 90 | assert(!moduleOp->getRegion(0).empty() && "expected non-empty region" ); |
| 91 | b.setInsertionPointToStart(&moduleOp->getRegion(index: 0).front()); |
| 92 | auto funcOp = b.create<LLVM::LLVMFuncOp>( |
| 93 | location: moduleOp->getLoc(), args&: name, |
| 94 | args: LLVM::LLVMFunctionType::get(result: resultType, arguments: paramTypes, isVarArg)); |
| 95 | |
| 96 | if (symbolTables) { |
| 97 | SymbolTable &symbolTable = symbolTables->getSymbolTable(op: moduleOp); |
| 98 | symbolTable.insert(symbol: funcOp, insertPt: moduleOp->getRegion(index: 0).front().begin()); |
| 99 | } |
| 100 | |
| 101 | return funcOp; |
| 102 | } |
| 103 | |
| 104 | static FailureOr<LLVM::LLVMFuncOp> |
| 105 | lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, |
| 106 | ArrayRef<Type> paramTypes, Type resultType, |
| 107 | SymbolTableCollection *symbolTables) { |
| 108 | return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType, |
| 109 | /*isVarArg=*/false, /*isReserved=*/true, |
| 110 | symbolTables); |
| 111 | } |
| 112 | |
| 113 | FailureOr<LLVM::LLVMFuncOp> |
| 114 | mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, |
| 115 | SymbolTableCollection *symbolTables) { |
| 116 | return lookupOrCreateReservedFn( |
| 117 | b, moduleOp, name: kPrintI64, paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 64), |
| 118 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 119 | } |
| 120 | |
| 121 | FailureOr<LLVM::LLVMFuncOp> |
| 122 | mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, |
| 123 | SymbolTableCollection *symbolTables) { |
| 124 | return lookupOrCreateReservedFn( |
| 125 | b, moduleOp, name: kPrintU64, paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 64), |
| 126 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 127 | } |
| 128 | |
| 129 | FailureOr<LLVM::LLVMFuncOp> |
| 130 | mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, |
| 131 | SymbolTableCollection *symbolTables) { |
| 132 | return lookupOrCreateReservedFn( |
| 133 | b, moduleOp, name: kPrintF16, |
| 134 | paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 16), // bits! |
| 135 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 136 | } |
| 137 | |
| 138 | FailureOr<LLVM::LLVMFuncOp> |
| 139 | mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, |
| 140 | SymbolTableCollection *symbolTables) { |
| 141 | return lookupOrCreateReservedFn( |
| 142 | b, moduleOp, name: kPrintBF16, |
| 143 | paramTypes: IntegerType::get(context: moduleOp->getContext(), width: 16), // bits! |
| 144 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 145 | } |
| 146 | |
| 147 | FailureOr<LLVM::LLVMFuncOp> |
| 148 | mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, |
| 149 | SymbolTableCollection *symbolTables) { |
| 150 | return lookupOrCreateReservedFn( |
| 151 | b, moduleOp, name: kPrintF32, paramTypes: Float32Type::get(context: moduleOp->getContext()), |
| 152 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 153 | } |
| 154 | |
| 155 | FailureOr<LLVM::LLVMFuncOp> |
| 156 | mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, |
| 157 | SymbolTableCollection *symbolTables) { |
| 158 | return lookupOrCreateReservedFn( |
| 159 | b, moduleOp, name: kPrintF64, paramTypes: Float64Type::get(context: moduleOp->getContext()), |
| 160 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 161 | } |
| 162 | |
| 163 | static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { |
| 164 | return LLVM::LLVMPointerType::get(context); |
| 165 | } |
| 166 | |
| 167 | static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { |
| 168 | // A char pointer and void ptr are the same in LLVM IR. |
| 169 | return getCharPtr(context); |
| 170 | } |
| 171 | |
| 172 | FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn( |
| 173 | OpBuilder &b, Operation *moduleOp, |
| 174 | std::optional<StringRef> runtimeFunctionName, |
| 175 | SymbolTableCollection *symbolTables) { |
| 176 | return lookupOrCreateReservedFn( |
| 177 | b, moduleOp, name: runtimeFunctionName.value_or(u: kPrintString), |
| 178 | paramTypes: getCharPtr(context: moduleOp->getContext()), |
| 179 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 180 | } |
| 181 | |
| 182 | FailureOr<LLVM::LLVMFuncOp> |
| 183 | mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, |
| 184 | SymbolTableCollection *symbolTables) { |
| 185 | return lookupOrCreateReservedFn( |
| 186 | b, moduleOp, name: kPrintOpen, paramTypes: {}, |
| 187 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 188 | } |
| 189 | |
| 190 | FailureOr<LLVM::LLVMFuncOp> |
| 191 | mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, |
| 192 | SymbolTableCollection *symbolTables) { |
| 193 | return lookupOrCreateReservedFn( |
| 194 | b, moduleOp, name: kPrintClose, paramTypes: {}, |
| 195 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 196 | } |
| 197 | |
| 198 | FailureOr<LLVM::LLVMFuncOp> |
| 199 | mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, |
| 200 | SymbolTableCollection *symbolTables) { |
| 201 | return lookupOrCreateReservedFn( |
| 202 | b, moduleOp, name: kPrintComma, paramTypes: {}, |
| 203 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 204 | } |
| 205 | |
| 206 | FailureOr<LLVM::LLVMFuncOp> |
| 207 | mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, |
| 208 | SymbolTableCollection *symbolTables) { |
| 209 | return lookupOrCreateReservedFn( |
| 210 | b, moduleOp, name: kPrintNewline, paramTypes: {}, |
| 211 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 212 | } |
| 213 | |
| 214 | FailureOr<LLVM::LLVMFuncOp> |
| 215 | mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, |
| 216 | Type indexType, |
| 217 | SymbolTableCollection *symbolTables) { |
| 218 | return lookupOrCreateReservedFn(b, moduleOp, name: kMalloc, paramTypes: indexType, |
| 219 | resultType: getVoidPtr(context: moduleOp->getContext()), |
| 220 | symbolTables); |
| 221 | } |
| 222 | |
| 223 | FailureOr<LLVM::LLVMFuncOp> |
| 224 | mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, |
| 225 | Type indexType, |
| 226 | SymbolTableCollection *symbolTables) { |
| 227 | return lookupOrCreateReservedFn( |
| 228 | b, moduleOp, name: kAlignedAlloc, paramTypes: {indexType, indexType}, |
| 229 | resultType: getVoidPtr(context: moduleOp->getContext()), symbolTables); |
| 230 | } |
| 231 | |
| 232 | FailureOr<LLVM::LLVMFuncOp> |
| 233 | mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, |
| 234 | SymbolTableCollection *symbolTables) { |
| 235 | return lookupOrCreateReservedFn( |
| 236 | b, moduleOp, name: kFree, paramTypes: getVoidPtr(context: moduleOp->getContext()), |
| 237 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 238 | } |
| 239 | |
| 240 | FailureOr<LLVM::LLVMFuncOp> |
| 241 | mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, |
| 242 | Type indexType, |
| 243 | SymbolTableCollection *symbolTables) { |
| 244 | return lookupOrCreateReservedFn(b, moduleOp, name: kGenericAlloc, paramTypes: indexType, |
| 245 | resultType: getVoidPtr(context: moduleOp->getContext()), |
| 246 | symbolTables); |
| 247 | } |
| 248 | |
| 249 | FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn( |
| 250 | OpBuilder &b, Operation *moduleOp, Type indexType, |
| 251 | SymbolTableCollection *symbolTables) { |
| 252 | return lookupOrCreateReservedFn( |
| 253 | b, moduleOp, name: kGenericAlignedAlloc, paramTypes: {indexType, indexType}, |
| 254 | resultType: getVoidPtr(context: moduleOp->getContext()), symbolTables); |
| 255 | } |
| 256 | |
| 257 | FailureOr<LLVM::LLVMFuncOp> |
| 258 | mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, |
| 259 | SymbolTableCollection *symbolTables) { |
| 260 | return lookupOrCreateReservedFn( |
| 261 | b, moduleOp, name: kGenericFree, paramTypes: getVoidPtr(context: moduleOp->getContext()), |
| 262 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 263 | } |
| 264 | |
| 265 | FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn( |
| 266 | OpBuilder &b, Operation *moduleOp, Type indexType, |
| 267 | Type unrankedDescriptorType, SymbolTableCollection *symbolTables) { |
| 268 | return lookupOrCreateReservedFn( |
| 269 | b, moduleOp, name: kMemRefCopy, |
| 270 | paramTypes: ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType}, |
| 271 | resultType: LLVM::LLVMVoidType::get(ctx: moduleOp->getContext()), symbolTables); |
| 272 | } |
| 273 | |