1 | //===- Invoke.cpp ------------------------------------*- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
10 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
11 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
12 | #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" |
13 | #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
14 | #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/Linalg/Passes.h" |
17 | #include "mlir/ExecutionEngine/CRunnerUtils.h" |
18 | #include "mlir/ExecutionEngine/ExecutionEngine.h" |
19 | #include "mlir/ExecutionEngine/MemRefUtils.h" |
20 | #include "mlir/ExecutionEngine/RunnerUtils.h" |
21 | #include "mlir/IR/MLIRContext.h" |
22 | #include "mlir/InitAllDialects.h" |
23 | #include "mlir/Parser/Parser.h" |
24 | #include "mlir/Pass/PassManager.h" |
25 | #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
26 | #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
27 | #include "mlir/Target/LLVMIR/Export.h" |
28 | #include "llvm/Support/TargetSelect.h" |
29 | #include "llvm/Support/raw_ostream.h" |
30 | |
31 | #include "gmock/gmock.h" |
32 | |
// JIT support is currently missing on SPARC. SKIP_WITHOUT_JIT(name) therefore
// expands to DISABLED_name there (the GoogleTest spelling for a disabled test)
// and to the unmodified name on every other target.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39 | |
40 | using namespace mlir; |
41 | |
// The JIT isn't supported on Windows at this time.
43 | #ifndef _WIN32 |
44 | |
45 | static struct LLVMInitializer { |
46 | LLVMInitializer() { |
47 | llvm::InitializeNativeTarget(); |
48 | llvm::InitializeNativeTargetAsmPrinter(); |
49 | } |
50 | } initializer; |
51 | |
52 | /// Simple conversion pipeline for the purpose of testing sources written in |
53 | /// dialects lowering to LLVM Dialect. |
54 | static LogicalResult lowerToLLVMDialect(ModuleOp module) { |
55 | PassManager pm(module->getName()); |
56 | pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); |
57 | pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass()); |
58 | pm.addPass(mlir::createConvertFuncToLLVMPass()); |
59 | pm.addPass(mlir::createReconcileUnrealizedCastsPass()); |
60 | return pm.run(op: module); |
61 | } |
62 | |
63 | TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) { |
64 | #ifdef __s390__ |
65 | std::string moduleStr = R"mlir( |
66 | func.func @foo(%arg0 : i32 {llvm.signext}) -> (i32 {llvm.signext}) attributes { llvm.emit_c_interface } { |
67 | %res = arith.addi %arg0, %arg0 : i32 |
68 | return %res : i32 |
69 | } |
70 | )mlir" ; |
71 | #else |
72 | std::string moduleStr = R"mlir( |
73 | func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { |
74 | %res = arith.addi %arg0, %arg0 : i32 |
75 | return %res : i32 |
76 | } |
77 | )mlir" ; |
78 | #endif |
79 | DialectRegistry registry; |
80 | registerAllDialects(registry); |
81 | registerBuiltinDialectTranslation(registry); |
82 | registerLLVMDialectTranslation(registry); |
83 | MLIRContext context(registry); |
84 | OwningOpRef<ModuleOp> module = |
85 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
86 | ASSERT_TRUE(!!module); |
87 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
88 | auto jitOrError = ExecutionEngine::create(op: *module); |
89 | ASSERT_TRUE(!!jitOrError); |
90 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
91 | // The result of the function must be passed as output argument. |
92 | int result = 0; |
93 | llvm::Error error = |
94 | jit->invoke(funcName: "foo" , args: 42, args: ExecutionEngine::Result<int>(result)); |
95 | ASSERT_TRUE(!error); |
96 | ASSERT_EQ(result, 42 + 42); |
97 | } |
98 | |
99 | TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) { |
100 | std::string moduleStr = R"mlir( |
101 | func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { |
102 | %res = arith.subf %arg0, %arg1 : f32 |
103 | return %res : f32 |
104 | } |
105 | )mlir" ; |
106 | DialectRegistry registry; |
107 | registerAllDialects(registry); |
108 | registerBuiltinDialectTranslation(registry); |
109 | registerLLVMDialectTranslation(registry); |
110 | MLIRContext context(registry); |
111 | OwningOpRef<ModuleOp> module = |
112 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
113 | ASSERT_TRUE(!!module); |
114 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
115 | auto jitOrError = ExecutionEngine::create(op: *module); |
116 | ASSERT_TRUE(!!jitOrError); |
117 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
118 | // The result of the function must be passed as output argument. |
119 | float result = -1; |
120 | llvm::Error error = |
121 | jit->invoke(funcName: "foo" , args: 43.0f, args: 1.0f, args: ExecutionEngine::result(t&: result)); |
122 | ASSERT_TRUE(!error); |
123 | ASSERT_EQ(result, 42.f); |
124 | } |
125 | |
126 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) { |
127 | OwningMemRef<float, 0> a({}); |
128 | a[{}] = 42.; |
129 | ASSERT_EQ(*a->data, 42); |
130 | a[{}] = 0; |
131 | std::string moduleStr = R"mlir( |
132 | func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { |
133 | %cst42 = arith.constant 42.0 : f32 |
134 | memref.store %cst42, %arg0[] : memref<f32> |
135 | return |
136 | } |
137 | )mlir" ; |
138 | DialectRegistry registry; |
139 | registerAllDialects(registry); |
140 | registerBuiltinDialectTranslation(registry); |
141 | registerLLVMDialectTranslation(registry); |
142 | MLIRContext context(registry); |
143 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
144 | ASSERT_TRUE(!!module); |
145 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
146 | auto jitOrError = ExecutionEngine::create(op: *module); |
147 | ASSERT_TRUE(!!jitOrError); |
148 | auto jit = std::move(jitOrError.get()); |
149 | |
150 | llvm::Error error = jit->invoke("zero_ranked" , &*a); |
151 | ASSERT_TRUE(!error); |
152 | EXPECT_EQ((a[{}]), 42.); |
153 | for (float &elt : *a) |
154 | EXPECT_EQ(&elt, &(a[{}])); |
155 | } |
156 | |
157 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) { |
158 | int64_t shape[] = {9}; |
159 | OwningMemRef<float, 1> a(shape); |
160 | int count = 1; |
161 | for (float &elt : *a) { |
162 | EXPECT_EQ(&elt, &(a[{count - 1}])); |
163 | elt = count++; |
164 | } |
165 | |
166 | std::string moduleStr = R"mlir( |
167 | func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { |
168 | %cst42 = arith.constant 42.0 : f32 |
169 | %cst5 = arith.constant 5 : index |
170 | memref.store %cst42, %arg0[%cst5] : memref<?xf32> |
171 | return |
172 | } |
173 | )mlir" ; |
174 | DialectRegistry registry; |
175 | registerAllDialects(registry); |
176 | registerBuiltinDialectTranslation(registry); |
177 | registerLLVMDialectTranslation(registry); |
178 | MLIRContext context(registry); |
179 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
180 | ASSERT_TRUE(!!module); |
181 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
182 | auto jitOrError = ExecutionEngine::create(op: *module); |
183 | ASSERT_TRUE(!!jitOrError); |
184 | auto jit = std::move(jitOrError.get()); |
185 | |
186 | llvm::Error error = jit->invoke("one_ranked" , &*a); |
187 | ASSERT_TRUE(!error); |
188 | count = 1; |
189 | for (float &elt : *a) { |
190 | if (count == 6) |
191 | EXPECT_EQ(elt, 42.); |
192 | else |
193 | EXPECT_EQ(elt, count); |
194 | count++; |
195 | } |
196 | } |
197 | |
198 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) { |
199 | constexpr int k = 3; |
200 | constexpr int m = 7; |
201 | // Prepare arguments beforehand. |
202 | auto init = [=](float &elt, ArrayRef<int64_t> indices) { |
203 | assert(indices.size() == 2); |
204 | elt = m * indices[0] + indices[1]; |
205 | }; |
206 | int64_t shape[] = {k, m}; |
207 | int64_t shapeAlloc[] = {k + 1, m + 1}; |
208 | OwningMemRef<float, 2> a(shape, shapeAlloc, init); |
209 | ASSERT_EQ(a->sizes[0], k); |
210 | ASSERT_EQ(a->sizes[1], m); |
211 | ASSERT_EQ(a->strides[0], m + 1); |
212 | ASSERT_EQ(a->strides[1], 1); |
213 | for (int i = 0; i < k; ++i) { |
214 | for (int j = 0; j < m; ++j) { |
215 | EXPECT_EQ((a[{i, j}]), i * m + j); |
216 | EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j])); |
217 | } |
218 | } |
219 | std::string moduleStr = R"mlir( |
220 | func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { |
221 | %x = arith.constant 2 : index |
222 | %y = arith.constant 1 : index |
223 | %cst42 = arith.constant 42.0 : f32 |
224 | memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> |
225 | memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> |
226 | return |
227 | } |
228 | )mlir" ; |
229 | DialectRegistry registry; |
230 | registerAllDialects(registry); |
231 | registerBuiltinDialectTranslation(registry); |
232 | registerLLVMDialectTranslation(registry); |
233 | MLIRContext context(registry); |
234 | OwningOpRef<ModuleOp> module = |
235 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
236 | ASSERT_TRUE(!!module); |
237 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
238 | auto jitOrError = ExecutionEngine::create(op: *module); |
239 | ASSERT_TRUE(!!jitOrError); |
240 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
241 | |
242 | llvm::Error error = jit->invoke(funcName: "rank2_memref" , args: &*a, args: &*a); |
243 | ASSERT_TRUE(!error); |
244 | EXPECT_EQ(((*a)[1][2]), 42.); |
245 | EXPECT_EQ((a[{2, 1}]), 42.); |
246 | } |
247 | |
248 | // A helper function that will be called from the JIT |
249 | static void memrefMultiply(::StridedMemRefType<float, 2> *memref, |
250 | int32_t coefficient) { |
251 | for (float &elt : *memref) |
252 | elt *= coefficient; |
253 | } |
254 | |
255 | // MSAN does not work with JIT. |
256 | #if __has_feature(memory_sanitizer) |
257 | #define MAYBE_JITCallback DISABLED_JITCallback |
258 | #else |
259 | #define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback) |
260 | #endif |
261 | TEST(NativeMemRefJit, MAYBE_JITCallback) { |
262 | constexpr int k = 2; |
263 | constexpr int m = 2; |
264 | int64_t shape[] = {k, m}; |
265 | int64_t shapeAlloc[] = {k + 1, m + 1}; |
266 | OwningMemRef<float, 2> a(shape, shapeAlloc); |
267 | int count = 1; |
268 | for (float &elt : *a) |
269 | elt = count++; |
270 | |
271 | #ifdef __s390__ |
272 | std::string moduleStr = R"mlir( |
273 | func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } |
274 | func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } { |
275 | %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> |
276 | call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () |
277 | return |
278 | } |
279 | )mlir" ; |
280 | #else |
281 | std::string moduleStr = R"mlir( |
282 | func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } |
283 | func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { |
284 | %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> |
285 | call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () |
286 | return |
287 | } |
288 | )mlir" ; |
289 | #endif |
290 | |
291 | DialectRegistry registry; |
292 | registerAllDialects(registry); |
293 | registerBuiltinDialectTranslation(registry); |
294 | registerLLVMDialectTranslation(registry); |
295 | MLIRContext context(registry); |
296 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
297 | ASSERT_TRUE(!!module); |
298 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
299 | auto jitOrError = ExecutionEngine::create(op: *module); |
300 | ASSERT_TRUE(!!jitOrError); |
301 | auto jit = std::move(jitOrError.get()); |
302 | // Define any extra symbols so they're available at runtime. |
303 | jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { |
304 | llvm::orc::SymbolMap symbolMap; |
305 | symbolMap[interner("_mlir_ciface_callback" )] = { |
306 | llvm::orc::ExecutorAddr::fromPtr(Ptr: memrefMultiply), |
307 | llvm::JITSymbolFlags::Exported}; |
308 | return symbolMap; |
309 | }); |
310 | |
311 | int32_t coefficient = 3.; |
312 | llvm::Error error = jit->invoke("caller_for_callback" , &*a, coefficient); |
313 | ASSERT_TRUE(!error); |
314 | count = 1; |
315 | for (float elt : *a) |
316 | ASSERT_EQ(elt, coefficient * count++); |
317 | } |
318 | |
319 | #endif // _WIN32 |
320 | |