1//===- Invoke.cpp ------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "gmock/gmock.h"
32
// JIT support is currently missing on SPARC. Wrapping a test name in this
// macro gives it gtest's `DISABLED_` prefix on SPARC hosts, so the test is
// still compiled but not run there; elsewhere the name is left untouched.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39
40using namespace mlir;
41
// The JIT isn't supported on Windows at this time.
43#ifndef _WIN32
44
45static struct LLVMInitializer {
46 LLVMInitializer() {
47 llvm::InitializeNativeTarget();
48 llvm::InitializeNativeTargetAsmPrinter();
49 }
50} initializer;
51
52/// Simple conversion pipeline for the purpose of testing sources written in
53/// dialects lowering to LLVM Dialect.
54static LogicalResult lowerToLLVMDialect(ModuleOp module) {
55 PassManager pm(module->getName());
56 pm.addPass(pass: mlir::createFinalizeMemRefToLLVMConversionPass());
57 pm.addNestedPass<func::FuncOp>(pass: mlir::createArithToLLVMConversionPass());
58 pm.addPass(pass: mlir::createConvertFuncToLLVMPass());
59 pm.addPass(pass: mlir::createReconcileUnrealizedCastsPass());
60 return pm.run(op: module);
61}
62
63TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) {
64#ifdef __s390__
65 std::string moduleStr = R"mlir(
66 func.func @foo(%arg0 : i32 {llvm.signext}) -> (i32 {llvm.signext}) attributes { llvm.emit_c_interface } {
67 %res = arith.addi %arg0, %arg0 : i32
68 return %res : i32
69 }
70 )mlir";
71#else
72 std::string moduleStr = R"mlir(
73 func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
74 %res = arith.addi %arg0, %arg0 : i32
75 return %res : i32
76 }
77 )mlir";
78#endif
79 DialectRegistry registry;
80 registerAllDialects(registry);
81 registerBuiltinDialectTranslation(registry);
82 registerLLVMDialectTranslation(registry);
83 MLIRContext context(registry);
84 OwningOpRef<ModuleOp> module =
85 parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
86 ASSERT_TRUE(!!module);
87 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
88 auto jitOrError = ExecutionEngine::create(op: *module);
89 ASSERT_TRUE(!!jitOrError);
90 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
91 // The result of the function must be passed as output argument.
92 int result = 0;
93 llvm::Error error =
94 jit->invoke(funcName: "foo", args: 42, args: ExecutionEngine::Result<int>(result));
95 ASSERT_TRUE(!error);
96 ASSERT_EQ(result, 42 + 42);
97}
98
99TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) {
100 std::string moduleStr = R"mlir(
101 func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
102 %res = arith.subf %arg0, %arg1 : f32
103 return %res : f32
104 }
105 )mlir";
106 DialectRegistry registry;
107 registerAllDialects(registry);
108 registerBuiltinDialectTranslation(registry);
109 registerLLVMDialectTranslation(registry);
110 MLIRContext context(registry);
111 OwningOpRef<ModuleOp> module =
112 parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
113 ASSERT_TRUE(!!module);
114 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
115 auto jitOrError = ExecutionEngine::create(op: *module);
116 ASSERT_TRUE(!!jitOrError);
117 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
118 // The result of the function must be passed as output argument.
119 float result = -1;
120 llvm::Error error =
121 jit->invoke(funcName: "foo", args: 43.0f, args: 1.0f, args: ExecutionEngine::result(t&: result));
122 ASSERT_TRUE(!error);
123 ASSERT_EQ(result, 42.f);
124}
125
126TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) {
127 OwningMemRef<float, 0> a({});
128 a[{}] = 42.;
129 ASSERT_EQ(*a->data, 42);
130 a[{}] = 0;
131 std::string moduleStr = R"mlir(
132 func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
133 %cst42 = arith.constant 42.0 : f32
134 memref.store %cst42, %arg0[] : memref<f32>
135 return
136 }
137 )mlir";
138 DialectRegistry registry;
139 registerAllDialects(registry);
140 registerBuiltinDialectTranslation(registry);
141 registerLLVMDialectTranslation(registry);
142 MLIRContext context(registry);
143 auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
144 ASSERT_TRUE(!!module);
145 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
146 auto jitOrError = ExecutionEngine::create(op: *module);
147 ASSERT_TRUE(!!jitOrError);
148 auto jit = std::move(jitOrError.get());
149
150 llvm::Error error = jit->invoke(funcName: "zero_ranked", args: &*a);
151 ASSERT_TRUE(!error);
152 EXPECT_EQ((a[{}]), 42.);
153 for (float &elt : *a)
154 EXPECT_EQ(&elt, &(a[{}]));
155}
156
157TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) {
158 int64_t shape[] = {9};
159 OwningMemRef<float, 1> a(shape);
160 int count = 1;
161 for (float &elt : *a) {
162 EXPECT_EQ(&elt, &(a[{count - 1}]));
163 elt = count++;
164 }
165
166 std::string moduleStr = R"mlir(
167 func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
168 %cst42 = arith.constant 42.0 : f32
169 %cst5 = arith.constant 5 : index
170 memref.store %cst42, %arg0[%cst5] : memref<?xf32>
171 return
172 }
173 )mlir";
174 DialectRegistry registry;
175 registerAllDialects(registry);
176 registerBuiltinDialectTranslation(registry);
177 registerLLVMDialectTranslation(registry);
178 MLIRContext context(registry);
179 auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
180 ASSERT_TRUE(!!module);
181 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
182 auto jitOrError = ExecutionEngine::create(op: *module);
183 ASSERT_TRUE(!!jitOrError);
184 auto jit = std::move(jitOrError.get());
185
186 llvm::Error error = jit->invoke(funcName: "one_ranked", args: &*a);
187 ASSERT_TRUE(!error);
188 count = 1;
189 for (float &elt : *a) {
190 if (count == 6)
191 EXPECT_EQ(elt, 42.);
192 else
193 EXPECT_EQ(elt, count);
194 count++;
195 }
196}
197
198TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
199 constexpr int k = 3;
200 constexpr int m = 7;
201 // Prepare arguments beforehand.
202 auto init = [=](float &elt, ArrayRef<int64_t> indices) {
203 assert(indices.size() == 2);
204 elt = m * indices[0] + indices[1];
205 };
206 int64_t shape[] = {k, m};
207 int64_t shapeAlloc[] = {k + 1, m + 1};
208 // Use a large alignment to stress the case where the memref data/basePtr are
209 // disjoint.
210 int alignment = 8192;
211 OwningMemRef<float, 2> a(shape, shapeAlloc, init, alignment);
212 ASSERT_EQ(
213 (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
214 a->data);
215 ASSERT_EQ(a->sizes[0], k);
216 ASSERT_EQ(a->sizes[1], m);
217 ASSERT_EQ(a->strides[0], m + 1);
218 ASSERT_EQ(a->strides[1], 1);
219 for (int i = 0; i < k; ++i) {
220 for (int j = 0; j < m; ++j) {
221 EXPECT_EQ((a[{i, j}]), i * m + j);
222 EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
223 }
224 }
225 std::string moduleStr = R"mlir(
226 func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
227 %x = arith.constant 2 : index
228 %y = arith.constant 1 : index
229 %cst42 = arith.constant 42.0 : f32
230 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
231 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
232 return
233 }
234 )mlir";
235 DialectRegistry registry;
236 registerAllDialects(registry);
237 registerBuiltinDialectTranslation(registry);
238 registerLLVMDialectTranslation(registry);
239 MLIRContext context(registry);
240 OwningOpRef<ModuleOp> module =
241 parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
242 ASSERT_TRUE(!!module);
243 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
244 auto jitOrError = ExecutionEngine::create(op: *module);
245 ASSERT_TRUE(!!jitOrError);
246 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
247
248 llvm::Error error = jit->invoke(funcName: "rank2_memref", args: &*a, args: &*a);
249 ASSERT_TRUE(!error);
250 EXPECT_EQ(((*a)[1][2]), 42.);
251 EXPECT_EQ((a[{2, 1}]), 42.);
252}
253
254// A helper function that will be called from the JIT
255static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
256 int32_t coefficient) {
257 for (float &elt : *memref)
258 elt *= coefficient;
259}
260
// MSAN does not work with JIT, so the callback test is disabled when building
// with MemorySanitizer. `__has_feature` is a Clang extension that not every
// compiler provides (e.g. GCC before 14). Check that it exists before using
// it rather than relying on another header to define a fallback.
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#define MAYBE_JITCallback DISABLED_JITCallback
#endif
#endif
#ifndef MAYBE_JITCallback
#define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
#endif
267TEST(NativeMemRefJit, MAYBE_JITCallback) {
268 constexpr int k = 2;
269 constexpr int m = 2;
270 int64_t shape[] = {k, m};
271 int64_t shapeAlloc[] = {k + 1, m + 1};
272 OwningMemRef<float, 2> a(shape, shapeAlloc);
273 int count = 1;
274 for (float &elt : *a)
275 elt = count++;
276
277#ifdef __s390__
278 std::string moduleStr = R"mlir(
279 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface }
280 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } {
281 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
282 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
283 return
284 }
285 )mlir";
286#else
287 std::string moduleStr = R"mlir(
288 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
289 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
290 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
291 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
292 return
293 }
294 )mlir";
295#endif
296
297 DialectRegistry registry;
298 registerAllDialects(registry);
299 registerBuiltinDialectTranslation(registry);
300 registerLLVMDialectTranslation(registry);
301 MLIRContext context(registry);
302 auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context);
303 ASSERT_TRUE(!!module);
304 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
305 auto jitOrError = ExecutionEngine::create(op: *module);
306 ASSERT_TRUE(!!jitOrError);
307 auto jit = std::move(jitOrError.get());
308 // Define any extra symbols so they're available at runtime.
309 jit->registerSymbols(symbolMap: [&](llvm::orc::MangleAndInterner interner) {
310 llvm::orc::SymbolMap symbolMap;
311 symbolMap[interner("_mlir_ciface_callback")] = {
312 llvm::orc::ExecutorAddr::fromPtr(Ptr: memrefMultiply),
313 llvm::JITSymbolFlags::Exported};
314 return symbolMap;
315 });
316
317 int32_t coefficient = 3.;
318 llvm::Error error = jit->invoke(funcName: "caller_for_callback", args: &*a, args: coefficient);
319 ASSERT_TRUE(!error);
320 count = 1;
321 for (float elt : *a)
322 ASSERT_EQ(elt, coefficient * count++);
323}
324
325#endif // _WIN32
326

// source code of mlir/unittests/ExecutionEngine/Invoke.cpp