1//===- Invoke.cpp ------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "gmock/gmock.h"
32
// Wraps the name of a test that needs to JIT-compile code. SPARC currently
// lacks JIT support, so there the name gets googletest's `DISABLED_` prefix,
// which keeps the test from running; everywhere else the name is unchanged.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39
40using namespace mlir;
41
// The JIT isn't supported on Windows at this time.
43#ifndef _WIN32
44
45static struct LLVMInitializer {
46 LLVMInitializer() {
47 llvm::InitializeNativeTarget();
48 llvm::InitializeNativeTargetAsmPrinter();
49 }
50} initializer;
51
52/// Simple conversion pipeline for the purpose of testing sources written in
53/// dialects lowering to LLVM Dialect.
54static LogicalResult lowerToLLVMDialect(ModuleOp module) {
55 PassManager pm(module->getName());
56 pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
57 pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass());
58 pm.addPass(mlir::createConvertFuncToLLVMPass());
59 pm.addPass(pass: mlir::createReconcileUnrealizedCastsPass());
60 return pm.run(op: module);
61}
62
63TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) {
64 std::string moduleStr = R"mlir(
65 func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
66 %res = arith.addi %arg0, %arg0 : i32
67 return %res : i32
68 }
69 )mlir";
70 DialectRegistry registry;
71 registerAllDialects(registry);
72 registerBuiltinDialectTranslation(registry);
73 registerLLVMDialectTranslation(registry);
74 MLIRContext context(registry);
75 OwningOpRef<ModuleOp> module =
76 parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
77 ASSERT_TRUE(!!module);
78 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
79 auto jitOrError = ExecutionEngine::create(op: *module);
80 ASSERT_TRUE(!!jitOrError);
81 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
82 // The result of the function must be passed as output argument.
83 int result = 0;
84 llvm::Error error =
85 jit->invoke(funcName: "foo", args: 42, args: ExecutionEngine::Result<int>(result));
86 ASSERT_TRUE(!error);
87 ASSERT_EQ(result, 42 + 42);
88}
89
90TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) {
91 std::string moduleStr = R"mlir(
92 func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
93 %res = arith.subf %arg0, %arg1 : f32
94 return %res : f32
95 }
96 )mlir";
97 DialectRegistry registry;
98 registerAllDialects(registry);
99 registerBuiltinDialectTranslation(registry);
100 registerLLVMDialectTranslation(registry);
101 MLIRContext context(registry);
102 OwningOpRef<ModuleOp> module =
103 parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
104 ASSERT_TRUE(!!module);
105 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
106 auto jitOrError = ExecutionEngine::create(op: *module);
107 ASSERT_TRUE(!!jitOrError);
108 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
109 // The result of the function must be passed as output argument.
110 float result = -1;
111 llvm::Error error =
112 jit->invoke(funcName: "foo", args: 43.0f, args: 1.0f, args: ExecutionEngine::result(t&: result));
113 ASSERT_TRUE(!error);
114 ASSERT_EQ(result, 42.f);
115}
116
117TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) {
118 OwningMemRef<float, 0> a({});
119 a[{}] = 42.;
120 ASSERT_EQ(*a->data, 42);
121 a[{}] = 0;
122 std::string moduleStr = R"mlir(
123 func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
124 %cst42 = arith.constant 42.0 : f32
125 memref.store %cst42, %arg0[] : memref<f32>
126 return
127 }
128 )mlir";
129 DialectRegistry registry;
130 registerAllDialects(registry);
131 registerBuiltinDialectTranslation(registry);
132 registerLLVMDialectTranslation(registry);
133 MLIRContext context(registry);
134 auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
135 ASSERT_TRUE(!!module);
136 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
137 auto jitOrError = ExecutionEngine::create(op: *module);
138 ASSERT_TRUE(!!jitOrError);
139 auto jit = std::move(jitOrError.get());
140
141 llvm::Error error = jit->invoke("zero_ranked", &*a);
142 ASSERT_TRUE(!error);
143 EXPECT_EQ((a[{}]), 42.);
144 for (float &elt : *a)
145 EXPECT_EQ(&elt, &(a[{}]));
146}
147
148TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) {
149 int64_t shape[] = {9};
150 OwningMemRef<float, 1> a(shape);
151 int count = 1;
152 for (float &elt : *a) {
153 EXPECT_EQ(&elt, &(a[{count - 1}]));
154 elt = count++;
155 }
156
157 std::string moduleStr = R"mlir(
158 func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
159 %cst42 = arith.constant 42.0 : f32
160 %cst5 = arith.constant 5 : index
161 memref.store %cst42, %arg0[%cst5] : memref<?xf32>
162 return
163 }
164 )mlir";
165 DialectRegistry registry;
166 registerAllDialects(registry);
167 registerBuiltinDialectTranslation(registry);
168 registerLLVMDialectTranslation(registry);
169 MLIRContext context(registry);
170 auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
171 ASSERT_TRUE(!!module);
172 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
173 auto jitOrError = ExecutionEngine::create(op: *module);
174 ASSERT_TRUE(!!jitOrError);
175 auto jit = std::move(jitOrError.get());
176
177 llvm::Error error = jit->invoke("one_ranked", &*a);
178 ASSERT_TRUE(!error);
179 count = 1;
180 for (float &elt : *a) {
181 if (count == 6)
182 EXPECT_EQ(elt, 42.);
183 else
184 EXPECT_EQ(elt, count);
185 count++;
186 }
187}
188
189TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
190 constexpr int k = 3;
191 constexpr int m = 7;
192 // Prepare arguments beforehand.
193 auto init = [=](float &elt, ArrayRef<int64_t> indices) {
194 assert(indices.size() == 2);
195 elt = m * indices[0] + indices[1];
196 };
197 int64_t shape[] = {k, m};
198 int64_t shapeAlloc[] = {k + 1, m + 1};
199 OwningMemRef<float, 2> a(shape, shapeAlloc, init);
200 ASSERT_EQ(a->sizes[0], k);
201 ASSERT_EQ(a->sizes[1], m);
202 ASSERT_EQ(a->strides[0], m + 1);
203 ASSERT_EQ(a->strides[1], 1);
204 for (int i = 0; i < k; ++i) {
205 for (int j = 0; j < m; ++j) {
206 EXPECT_EQ((a[{i, j}]), i * m + j);
207 EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
208 }
209 }
210 std::string moduleStr = R"mlir(
211 func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
212 %x = arith.constant 2 : index
213 %y = arith.constant 1 : index
214 %cst42 = arith.constant 42.0 : f32
215 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
216 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
217 return
218 }
219 )mlir";
220 DialectRegistry registry;
221 registerAllDialects(registry);
222 registerBuiltinDialectTranslation(registry);
223 registerLLVMDialectTranslation(registry);
224 MLIRContext context(registry);
225 OwningOpRef<ModuleOp> module =
226 parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
227 ASSERT_TRUE(!!module);
228 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
229 auto jitOrError = ExecutionEngine::create(op: *module);
230 ASSERT_TRUE(!!jitOrError);
231 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
232
233 llvm::Error error = jit->invoke(funcName: "rank2_memref", args: &*a, args: &*a);
234 ASSERT_TRUE(!error);
235 EXPECT_EQ(((*a)[1][2]), 42.);
236 EXPECT_EQ((a[{2, 1}]), 42.);
237}
238
239// A helper function that will be called from the JIT
240static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
241 int32_t coefficient) {
242 for (float &elt : *memref)
243 elt *= coefficient;
244}
245
// MSAN does not work with JIT, so the callback test is disabled under it.
// `__has_feature` is a compiler extension that not every compiler provides.
// Probe for it in a separate `#if` so the check is well-formed everywhere,
// without relying on a fallback definition from some transitive header.
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#define MAYBE_JITCallback DISABLED_JITCallback
#endif
#endif
#ifndef MAYBE_JITCallback
#define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
#endif
252TEST(NativeMemRefJit, MAYBE_JITCallback) {
253 constexpr int k = 2;
254 constexpr int m = 2;
255 int64_t shape[] = {k, m};
256 int64_t shapeAlloc[] = {k + 1, m + 1};
257 OwningMemRef<float, 2> a(shape, shapeAlloc);
258 int count = 1;
259 for (float &elt : *a)
260 elt = count++;
261
262 std::string moduleStr = R"mlir(
263 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
264 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
265 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
266 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
267 return
268 }
269 )mlir";
270 DialectRegistry registry;
271 registerAllDialects(registry);
272 registerBuiltinDialectTranslation(registry);
273 registerLLVMDialectTranslation(registry);
274 MLIRContext context(registry);
275 auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
276 ASSERT_TRUE(!!module);
277 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
278 auto jitOrError = ExecutionEngine::create(op: *module);
279 ASSERT_TRUE(!!jitOrError);
280 auto jit = std::move(jitOrError.get());
281 // Define any extra symbols so they're available at runtime.
282 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
283 llvm::orc::SymbolMap symbolMap;
284 symbolMap[interner("_mlir_ciface_callback")] = {
285 llvm::orc::ExecutorAddr::fromPtr(Ptr: memrefMultiply),
286 llvm::JITSymbolFlags::Exported};
287 return symbolMap;
288 });
289
290 int32_t coefficient = 3.;
291 llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
292 ASSERT_TRUE(!error);
293 count = 1;
294 for (float elt : *a)
295 ASSERT_EQ(elt, coefficient * count++);
296}
297
298#endif // _WIN32
299

source code of mlir/unittests/ExecutionEngine/Invoke.cpp