1//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements helper functions for calling common simple C functions
// from LLVMIR (e.g., among others, to support printing and debugging).
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/Support/LLVM.h"
19
20using namespace mlir;
21using namespace mlir::LLVM;
22
23/// Helper functions to lookup or create the declaration for commonly used
24/// external C function calls. The list of functions provided here must be
25/// implemented separately (e.g. as part of a support runtime library or as
26/// part of the libc).
27static constexpr llvm::StringRef kPrintI64 = "printI64";
28static constexpr llvm::StringRef kPrintU64 = "printU64";
29static constexpr llvm::StringRef kPrintF16 = "printF16";
30static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31static constexpr llvm::StringRef kPrintF32 = "printF32";
32static constexpr llvm::StringRef kPrintF64 = "printF64";
33static constexpr llvm::StringRef kPrintString = "printString";
34static constexpr llvm::StringRef kPrintOpen = "printOpen";
35static constexpr llvm::StringRef kPrintClose = "printClose";
36static constexpr llvm::StringRef kPrintComma = "printComma";
37static constexpr llvm::StringRef kPrintNewline = "printNewline";
38static constexpr llvm::StringRef kMalloc = "malloc";
39static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
40static constexpr llvm::StringRef kFree = "free";
41static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
42static constexpr llvm::StringRef kGenericAlignedAlloc =
43 "_mlir_memref_to_llvm_aligned_alloc";
44static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
45static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
46
47/// Generic print function lookupOrCreate helper.
48LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
49 ArrayRef<Type> paramTypes,
50 Type resultType, bool isVarArg) {
51 auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
52 if (func)
53 return func;
54 OpBuilder b(moduleOp.getBodyRegion());
55 return b.create<LLVM::LLVMFuncOp>(
56 moduleOp->getLoc(), name,
57 LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
58}
59
60LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
61 return lookupOrCreateFn(moduleOp, kPrintI64,
62 IntegerType::get(moduleOp->getContext(), 64),
63 LLVM::LLVMVoidType::get(moduleOp->getContext()));
64}
65
66LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
67 return lookupOrCreateFn(moduleOp, kPrintU64,
68 IntegerType::get(moduleOp->getContext(), 64),
69 LLVM::LLVMVoidType::get(moduleOp->getContext()));
70}
71
72LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
73 return lookupOrCreateFn(moduleOp, kPrintF16,
74 IntegerType::get(moduleOp->getContext(), 16), // bits!
75 LLVM::LLVMVoidType::get(moduleOp->getContext()));
76}
77
78LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
79 return lookupOrCreateFn(moduleOp, kPrintBF16,
80 IntegerType::get(moduleOp->getContext(), 16), // bits!
81 LLVM::LLVMVoidType::get(moduleOp->getContext()));
82}
83
84LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
85 return lookupOrCreateFn(moduleOp, kPrintF32,
86 Float32Type::get(moduleOp->getContext()),
87 LLVM::LLVMVoidType::get(moduleOp->getContext()));
88}
89
90LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
91 return lookupOrCreateFn(moduleOp, kPrintF64,
92 Float64Type::get(moduleOp->getContext()),
93 LLVM::LLVMVoidType::get(moduleOp->getContext()));
94}
95
96static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
97 return LLVM::LLVMPointerType::get(context);
98}
99
100static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
101 // A char pointer and void ptr are the same in LLVM IR.
102 return getCharPtr(context);
103}
104
105LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
106 ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
107 return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
108 getCharPtr(moduleOp->getContext()),
109 LLVM::LLVMVoidType::get(moduleOp->getContext()));
110}
111
112LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
113 return lookupOrCreateFn(moduleOp, kPrintOpen, {},
114 LLVM::LLVMVoidType::get(moduleOp->getContext()));
115}
116
117LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
118 return lookupOrCreateFn(moduleOp, kPrintClose, {},
119 LLVM::LLVMVoidType::get(moduleOp->getContext()));
120}
121
122LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
123 return lookupOrCreateFn(moduleOp, kPrintComma, {},
124 LLVM::LLVMVoidType::get(moduleOp->getContext()));
125}
126
127LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
128 return lookupOrCreateFn(moduleOp, kPrintNewline, {},
129 LLVM::LLVMVoidType::get(moduleOp->getContext()));
130}
131
132LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
133 Type indexType) {
134 return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
135 getVoidPtr(moduleOp->getContext()));
136}
137
138LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
139 Type indexType) {
140 return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
141 getVoidPtr(moduleOp->getContext()));
142}
143
144LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
145 return LLVM::lookupOrCreateFn(
146 moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
147 LLVM::LLVMVoidType::get(moduleOp->getContext()));
148}
149
150LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
151 Type indexType) {
152 return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
153 getVoidPtr(moduleOp->getContext()));
154}
155
156LLVM::LLVMFuncOp
157mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
158 Type indexType) {
159 return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
160 {indexType, indexType},
161 getVoidPtr(moduleOp->getContext()));
162}
163
164LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
165 return LLVM::lookupOrCreateFn(
166 moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
167 LLVM::LLVMVoidType::get(moduleOp->getContext()));
168}
169
170LLVM::LLVMFuncOp
171mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
172 Type unrankedDescriptorType) {
173 return LLVM::lookupOrCreateFn(
174 moduleOp, kMemRefCopy,
175 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
176 LLVM::LLVMVoidType::get(moduleOp->getContext()));
177}
178

source code of mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp