1 | //===- Invoke.cpp ------------------------------------*- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "gmock/gmock.h"
32 | |
// JIT-dependent tests are wrapped in SKIP_WITHOUT_JIT(name). SPARC has no JIT
// support yet, so there the name gets gtest's DISABLED_ prefix. On every other
// target the name is passed through unchanged.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39 | |
40 | using namespace mlir; |
41 | |
// The JIT isn't supported on Windows at this time.
43 | #ifndef _WIN32 |
44 | |
45 | static struct LLVMInitializer { |
46 | LLVMInitializer() { |
47 | llvm::InitializeNativeTarget(); |
48 | llvm::InitializeNativeTargetAsmPrinter(); |
49 | } |
50 | } initializer; |
51 | |
52 | /// Simple conversion pipeline for the purpose of testing sources written in |
53 | /// dialects lowering to LLVM Dialect. |
54 | static LogicalResult lowerToLLVMDialect(ModuleOp module) { |
55 | PassManager pm(module->getName()); |
56 | pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); |
57 | pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass()); |
58 | pm.addPass(mlir::createConvertFuncToLLVMPass()); |
59 | pm.addPass(pass: mlir::createReconcileUnrealizedCastsPass()); |
60 | return pm.run(op: module); |
61 | } |
62 | |
63 | TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) { |
64 | std::string moduleStr = R"mlir( |
65 | func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { |
66 | %res = arith.addi %arg0, %arg0 : i32 |
67 | return %res : i32 |
68 | } |
69 | )mlir" ; |
70 | DialectRegistry registry; |
71 | registerAllDialects(registry); |
72 | registerBuiltinDialectTranslation(registry); |
73 | registerLLVMDialectTranslation(registry); |
74 | MLIRContext context(registry); |
75 | OwningOpRef<ModuleOp> module = |
76 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
77 | ASSERT_TRUE(!!module); |
78 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
79 | auto jitOrError = ExecutionEngine::create(op: *module); |
80 | ASSERT_TRUE(!!jitOrError); |
81 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
82 | // The result of the function must be passed as output argument. |
83 | int result = 0; |
84 | llvm::Error error = |
85 | jit->invoke(funcName: "foo" , args: 42, args: ExecutionEngine::Result<int>(result)); |
86 | ASSERT_TRUE(!error); |
87 | ASSERT_EQ(result, 42 + 42); |
88 | } |
89 | |
90 | TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) { |
91 | std::string moduleStr = R"mlir( |
92 | func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { |
93 | %res = arith.subf %arg0, %arg1 : f32 |
94 | return %res : f32 |
95 | } |
96 | )mlir" ; |
97 | DialectRegistry registry; |
98 | registerAllDialects(registry); |
99 | registerBuiltinDialectTranslation(registry); |
100 | registerLLVMDialectTranslation(registry); |
101 | MLIRContext context(registry); |
102 | OwningOpRef<ModuleOp> module = |
103 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
104 | ASSERT_TRUE(!!module); |
105 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
106 | auto jitOrError = ExecutionEngine::create(op: *module); |
107 | ASSERT_TRUE(!!jitOrError); |
108 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
109 | // The result of the function must be passed as output argument. |
110 | float result = -1; |
111 | llvm::Error error = |
112 | jit->invoke(funcName: "foo" , args: 43.0f, args: 1.0f, args: ExecutionEngine::result(t&: result)); |
113 | ASSERT_TRUE(!error); |
114 | ASSERT_EQ(result, 42.f); |
115 | } |
116 | |
117 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) { |
118 | OwningMemRef<float, 0> a({}); |
119 | a[{}] = 42.; |
120 | ASSERT_EQ(*a->data, 42); |
121 | a[{}] = 0; |
122 | std::string moduleStr = R"mlir( |
123 | func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { |
124 | %cst42 = arith.constant 42.0 : f32 |
125 | memref.store %cst42, %arg0[] : memref<f32> |
126 | return |
127 | } |
128 | )mlir" ; |
129 | DialectRegistry registry; |
130 | registerAllDialects(registry); |
131 | registerBuiltinDialectTranslation(registry); |
132 | registerLLVMDialectTranslation(registry); |
133 | MLIRContext context(registry); |
134 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
135 | ASSERT_TRUE(!!module); |
136 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
137 | auto jitOrError = ExecutionEngine::create(op: *module); |
138 | ASSERT_TRUE(!!jitOrError); |
139 | auto jit = std::move(jitOrError.get()); |
140 | |
141 | llvm::Error error = jit->invoke("zero_ranked" , &*a); |
142 | ASSERT_TRUE(!error); |
143 | EXPECT_EQ((a[{}]), 42.); |
144 | for (float &elt : *a) |
145 | EXPECT_EQ(&elt, &(a[{}])); |
146 | } |
147 | |
148 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) { |
149 | int64_t shape[] = {9}; |
150 | OwningMemRef<float, 1> a(shape); |
151 | int count = 1; |
152 | for (float &elt : *a) { |
153 | EXPECT_EQ(&elt, &(a[{count - 1}])); |
154 | elt = count++; |
155 | } |
156 | |
157 | std::string moduleStr = R"mlir( |
158 | func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { |
159 | %cst42 = arith.constant 42.0 : f32 |
160 | %cst5 = arith.constant 5 : index |
161 | memref.store %cst42, %arg0[%cst5] : memref<?xf32> |
162 | return |
163 | } |
164 | )mlir" ; |
165 | DialectRegistry registry; |
166 | registerAllDialects(registry); |
167 | registerBuiltinDialectTranslation(registry); |
168 | registerLLVMDialectTranslation(registry); |
169 | MLIRContext context(registry); |
170 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
171 | ASSERT_TRUE(!!module); |
172 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
173 | auto jitOrError = ExecutionEngine::create(op: *module); |
174 | ASSERT_TRUE(!!jitOrError); |
175 | auto jit = std::move(jitOrError.get()); |
176 | |
177 | llvm::Error error = jit->invoke("one_ranked" , &*a); |
178 | ASSERT_TRUE(!error); |
179 | count = 1; |
180 | for (float &elt : *a) { |
181 | if (count == 6) |
182 | EXPECT_EQ(elt, 42.); |
183 | else |
184 | EXPECT_EQ(elt, count); |
185 | count++; |
186 | } |
187 | } |
188 | |
189 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) { |
190 | constexpr int k = 3; |
191 | constexpr int m = 7; |
192 | // Prepare arguments beforehand. |
193 | auto init = [=](float &elt, ArrayRef<int64_t> indices) { |
194 | assert(indices.size() == 2); |
195 | elt = m * indices[0] + indices[1]; |
196 | }; |
197 | int64_t shape[] = {k, m}; |
198 | int64_t shapeAlloc[] = {k + 1, m + 1}; |
199 | OwningMemRef<float, 2> a(shape, shapeAlloc, init); |
200 | ASSERT_EQ(a->sizes[0], k); |
201 | ASSERT_EQ(a->sizes[1], m); |
202 | ASSERT_EQ(a->strides[0], m + 1); |
203 | ASSERT_EQ(a->strides[1], 1); |
204 | for (int i = 0; i < k; ++i) { |
205 | for (int j = 0; j < m; ++j) { |
206 | EXPECT_EQ((a[{i, j}]), i * m + j); |
207 | EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j])); |
208 | } |
209 | } |
210 | std::string moduleStr = R"mlir( |
211 | func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { |
212 | %x = arith.constant 2 : index |
213 | %y = arith.constant 1 : index |
214 | %cst42 = arith.constant 42.0 : f32 |
215 | memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> |
216 | memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> |
217 | return |
218 | } |
219 | )mlir" ; |
220 | DialectRegistry registry; |
221 | registerAllDialects(registry); |
222 | registerBuiltinDialectTranslation(registry); |
223 | registerLLVMDialectTranslation(registry); |
224 | MLIRContext context(registry); |
225 | OwningOpRef<ModuleOp> module = |
226 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
227 | ASSERT_TRUE(!!module); |
228 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
229 | auto jitOrError = ExecutionEngine::create(op: *module); |
230 | ASSERT_TRUE(!!jitOrError); |
231 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
232 | |
233 | llvm::Error error = jit->invoke(funcName: "rank2_memref" , args: &*a, args: &*a); |
234 | ASSERT_TRUE(!error); |
235 | EXPECT_EQ(((*a)[1][2]), 42.); |
236 | EXPECT_EQ((a[{2, 1}]), 42.); |
237 | } |
238 | |
239 | // A helper function that will be called from the JIT |
240 | static void memrefMultiply(::StridedMemRefType<float, 2> *memref, |
241 | int32_t coefficient) { |
242 | for (float &elt : *memref) |
243 | elt *= coefficient; |
244 | } |
245 | |
// MSAN does not work with JIT, so the callback test is disabled under
// MemorySanitizer.
//
// `__has_feature` is a compiler extension that not every compiler provides
// (GCC, for instance, only gained it in version 14). An unguarded
// `#if __has_feature(...)` is a preprocessing error where it is missing,
// unless some earlier header supplies a fallback. So probe for it with
// `defined` first, in a separate, nested `#if`.
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#define MAYBE_JITCallback DISABLED_JITCallback
#endif
#endif
#ifndef MAYBE_JITCallback
#define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
#endif
252 | TEST(NativeMemRefJit, MAYBE_JITCallback) { |
253 | constexpr int k = 2; |
254 | constexpr int m = 2; |
255 | int64_t shape[] = {k, m}; |
256 | int64_t shapeAlloc[] = {k + 1, m + 1}; |
257 | OwningMemRef<float, 2> a(shape, shapeAlloc); |
258 | int count = 1; |
259 | for (float &elt : *a) |
260 | elt = count++; |
261 | |
262 | std::string moduleStr = R"mlir( |
263 | func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } |
264 | func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { |
265 | %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> |
266 | call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () |
267 | return |
268 | } |
269 | )mlir" ; |
270 | DialectRegistry registry; |
271 | registerAllDialects(registry); |
272 | registerBuiltinDialectTranslation(registry); |
273 | registerLLVMDialectTranslation(registry); |
274 | MLIRContext context(registry); |
275 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
276 | ASSERT_TRUE(!!module); |
277 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
278 | auto jitOrError = ExecutionEngine::create(op: *module); |
279 | ASSERT_TRUE(!!jitOrError); |
280 | auto jit = std::move(jitOrError.get()); |
281 | // Define any extra symbols so they're available at runtime. |
282 | jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { |
283 | llvm::orc::SymbolMap symbolMap; |
284 | symbolMap[interner("_mlir_ciface_callback" )] = { |
285 | llvm::orc::ExecutorAddr::fromPtr(Ptr: memrefMultiply), |
286 | llvm::JITSymbolFlags::Exported}; |
287 | return symbolMap; |
288 | }); |
289 | |
290 | int32_t coefficient = 3.; |
291 | llvm::Error error = jit->invoke("caller_for_callback" , &*a, coefficient); |
292 | ASSERT_TRUE(!error); |
293 | count = 1; |
294 | for (float elt : *a) |
295 | ASSERT_EQ(elt, coefficient * count++); |
296 | } |
297 | |
298 | #endif // _WIN32 |
299 | |