| 1 | //===- Invoke.cpp ------------------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
| 10 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
| 11 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
| 12 | #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" |
| 13 | #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
| 14 | #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
| 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 16 | #include "mlir/Dialect/Linalg/Passes.h" |
| 17 | #include "mlir/ExecutionEngine/CRunnerUtils.h" |
| 18 | #include "mlir/ExecutionEngine/ExecutionEngine.h" |
| 19 | #include "mlir/ExecutionEngine/MemRefUtils.h" |
| 20 | #include "mlir/ExecutionEngine/RunnerUtils.h" |
| 21 | #include "mlir/IR/MLIRContext.h" |
| 22 | #include "mlir/InitAllDialects.h" |
| 23 | #include "mlir/Parser/Parser.h" |
| 24 | #include "mlir/Pass/PassManager.h" |
| 25 | #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
| 26 | #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
| 27 | #include "mlir/Target/LLVMIR/Export.h" |
| 28 | #include "llvm/Support/TargetSelect.h" |
| 29 | #include "llvm/Support/raw_ostream.h" |
| 30 | |
| 31 | #include "gmock/gmock.h" |
| 32 | |
// There is no JIT support on SPARC yet. Tests that need it get gtest's
// DISABLED_ name prefix there and are reported as disabled instead of run.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
| 39 | |
| 40 | using namespace mlir; |
| 41 | |
// The JIT isn't supported on Windows at this time.
| 43 | #ifndef _WIN32 |
| 44 | |
| 45 | static struct LLVMInitializer { |
| 46 | LLVMInitializer() { |
| 47 | llvm::InitializeNativeTarget(); |
| 48 | llvm::InitializeNativeTargetAsmPrinter(); |
| 49 | } |
| 50 | } initializer; |
| 51 | |
| 52 | /// Simple conversion pipeline for the purpose of testing sources written in |
| 53 | /// dialects lowering to LLVM Dialect. |
| 54 | static LogicalResult lowerToLLVMDialect(ModuleOp module) { |
| 55 | PassManager pm(module->getName()); |
| 56 | pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); |
| 57 | pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass()); |
| 58 | pm.addPass(mlir::createConvertFuncToLLVMPass()); |
| 59 | pm.addPass(mlir::createReconcileUnrealizedCastsPass()); |
| 60 | return pm.run(op: module); |
| 61 | } |
| 62 | |
| 63 | TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) { |
| 64 | #ifdef __s390__ |
| 65 | std::string moduleStr = R"mlir( |
| 66 | func.func @foo(%arg0 : i32 {llvm.signext}) -> (i32 {llvm.signext}) attributes { llvm.emit_c_interface } { |
| 67 | %res = arith.addi %arg0, %arg0 : i32 |
| 68 | return %res : i32 |
| 69 | } |
| 70 | )mlir" ; |
| 71 | #else |
| 72 | std::string moduleStr = R"mlir( |
| 73 | func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { |
| 74 | %res = arith.addi %arg0, %arg0 : i32 |
| 75 | return %res : i32 |
| 76 | } |
| 77 | )mlir" ; |
| 78 | #endif |
| 79 | DialectRegistry registry; |
| 80 | registerAllDialects(registry); |
| 81 | registerBuiltinDialectTranslation(registry); |
| 82 | registerLLVMDialectTranslation(registry); |
| 83 | MLIRContext context(registry); |
| 84 | OwningOpRef<ModuleOp> module = |
| 85 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
| 86 | ASSERT_TRUE(!!module); |
| 87 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
| 88 | auto jitOrError = ExecutionEngine::create(op: *module); |
| 89 | ASSERT_TRUE(!!jitOrError); |
| 90 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
| 91 | // The result of the function must be passed as output argument. |
| 92 | int result = 0; |
| 93 | llvm::Error error = |
| 94 | jit->invoke(funcName: "foo" , args: 42, args: ExecutionEngine::Result<int>(result)); |
| 95 | ASSERT_TRUE(!error); |
| 96 | ASSERT_EQ(result, 42 + 42); |
| 97 | } |
| 98 | |
| 99 | TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) { |
| 100 | std::string moduleStr = R"mlir( |
| 101 | func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { |
| 102 | %res = arith.subf %arg0, %arg1 : f32 |
| 103 | return %res : f32 |
| 104 | } |
| 105 | )mlir" ; |
| 106 | DialectRegistry registry; |
| 107 | registerAllDialects(registry); |
| 108 | registerBuiltinDialectTranslation(registry); |
| 109 | registerLLVMDialectTranslation(registry); |
| 110 | MLIRContext context(registry); |
| 111 | OwningOpRef<ModuleOp> module = |
| 112 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
| 113 | ASSERT_TRUE(!!module); |
| 114 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
| 115 | auto jitOrError = ExecutionEngine::create(op: *module); |
| 116 | ASSERT_TRUE(!!jitOrError); |
| 117 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
| 118 | // The result of the function must be passed as output argument. |
| 119 | float result = -1; |
| 120 | llvm::Error error = |
| 121 | jit->invoke(funcName: "foo" , args: 43.0f, args: 1.0f, args: ExecutionEngine::result(t&: result)); |
| 122 | ASSERT_TRUE(!error); |
| 123 | ASSERT_EQ(result, 42.f); |
| 124 | } |
| 125 | |
| 126 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) { |
| 127 | OwningMemRef<float, 0> a({}); |
| 128 | a[{}] = 42.; |
| 129 | ASSERT_EQ(*a->data, 42); |
| 130 | a[{}] = 0; |
| 131 | std::string moduleStr = R"mlir( |
| 132 | func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { |
| 133 | %cst42 = arith.constant 42.0 : f32 |
| 134 | memref.store %cst42, %arg0[] : memref<f32> |
| 135 | return |
| 136 | } |
| 137 | )mlir" ; |
| 138 | DialectRegistry registry; |
| 139 | registerAllDialects(registry); |
| 140 | registerBuiltinDialectTranslation(registry); |
| 141 | registerLLVMDialectTranslation(registry); |
| 142 | MLIRContext context(registry); |
| 143 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
| 144 | ASSERT_TRUE(!!module); |
| 145 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
| 146 | auto jitOrError = ExecutionEngine::create(op: *module); |
| 147 | ASSERT_TRUE(!!jitOrError); |
| 148 | auto jit = std::move(jitOrError.get()); |
| 149 | |
| 150 | llvm::Error error = jit->invoke("zero_ranked" , &*a); |
| 151 | ASSERT_TRUE(!error); |
| 152 | EXPECT_EQ((a[{}]), 42.); |
| 153 | for (float &elt : *a) |
| 154 | EXPECT_EQ(&elt, &(a[{}])); |
| 155 | } |
| 156 | |
| 157 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) { |
| 158 | int64_t shape[] = {9}; |
| 159 | OwningMemRef<float, 1> a(shape); |
| 160 | int count = 1; |
| 161 | for (float &elt : *a) { |
| 162 | EXPECT_EQ(&elt, &(a[{count - 1}])); |
| 163 | elt = count++; |
| 164 | } |
| 165 | |
| 166 | std::string moduleStr = R"mlir( |
| 167 | func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { |
| 168 | %cst42 = arith.constant 42.0 : f32 |
| 169 | %cst5 = arith.constant 5 : index |
| 170 | memref.store %cst42, %arg0[%cst5] : memref<?xf32> |
| 171 | return |
| 172 | } |
| 173 | )mlir" ; |
| 174 | DialectRegistry registry; |
| 175 | registerAllDialects(registry); |
| 176 | registerBuiltinDialectTranslation(registry); |
| 177 | registerLLVMDialectTranslation(registry); |
| 178 | MLIRContext context(registry); |
| 179 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
| 180 | ASSERT_TRUE(!!module); |
| 181 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
| 182 | auto jitOrError = ExecutionEngine::create(op: *module); |
| 183 | ASSERT_TRUE(!!jitOrError); |
| 184 | auto jit = std::move(jitOrError.get()); |
| 185 | |
| 186 | llvm::Error error = jit->invoke("one_ranked" , &*a); |
| 187 | ASSERT_TRUE(!error); |
| 188 | count = 1; |
| 189 | for (float &elt : *a) { |
| 190 | if (count == 6) |
| 191 | EXPECT_EQ(elt, 42.); |
| 192 | else |
| 193 | EXPECT_EQ(elt, count); |
| 194 | count++; |
| 195 | } |
| 196 | } |
| 197 | |
| 198 | TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) { |
| 199 | constexpr int k = 3; |
| 200 | constexpr int m = 7; |
| 201 | // Prepare arguments beforehand. |
| 202 | auto init = [=](float &elt, ArrayRef<int64_t> indices) { |
| 203 | assert(indices.size() == 2); |
| 204 | elt = m * indices[0] + indices[1]; |
| 205 | }; |
| 206 | int64_t shape[] = {k, m}; |
| 207 | int64_t shapeAlloc[] = {k + 1, m + 1}; |
| 208 | OwningMemRef<float, 2> a(shape, shapeAlloc, init); |
| 209 | ASSERT_EQ(a->sizes[0], k); |
| 210 | ASSERT_EQ(a->sizes[1], m); |
| 211 | ASSERT_EQ(a->strides[0], m + 1); |
| 212 | ASSERT_EQ(a->strides[1], 1); |
| 213 | for (int i = 0; i < k; ++i) { |
| 214 | for (int j = 0; j < m; ++j) { |
| 215 | EXPECT_EQ((a[{i, j}]), i * m + j); |
| 216 | EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j])); |
| 217 | } |
| 218 | } |
| 219 | std::string moduleStr = R"mlir( |
| 220 | func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { |
| 221 | %x = arith.constant 2 : index |
| 222 | %y = arith.constant 1 : index |
| 223 | %cst42 = arith.constant 42.0 : f32 |
| 224 | memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> |
| 225 | memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> |
| 226 | return |
| 227 | } |
| 228 | )mlir" ; |
| 229 | DialectRegistry registry; |
| 230 | registerAllDialects(registry); |
| 231 | registerBuiltinDialectTranslation(registry); |
| 232 | registerLLVMDialectTranslation(registry); |
| 233 | MLIRContext context(registry); |
| 234 | OwningOpRef<ModuleOp> module = |
| 235 | parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
| 236 | ASSERT_TRUE(!!module); |
| 237 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
| 238 | auto jitOrError = ExecutionEngine::create(op: *module); |
| 239 | ASSERT_TRUE(!!jitOrError); |
| 240 | std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); |
| 241 | |
| 242 | llvm::Error error = jit->invoke(funcName: "rank2_memref" , args: &*a, args: &*a); |
| 243 | ASSERT_TRUE(!error); |
| 244 | EXPECT_EQ(((*a)[1][2]), 42.); |
| 245 | EXPECT_EQ((a[{2, 1}]), 42.); |
| 246 | } |
| 247 | |
| 248 | // A helper function that will be called from the JIT |
| 249 | static void memrefMultiply(::StridedMemRefType<float, 2> *memref, |
| 250 | int32_t coefficient) { |
| 251 | for (float &elt : *memref) |
| 252 | elt *= coefficient; |
| 253 | } |
| 254 | |
// MSAN does not work with JIT, so the callback test is disabled under it.
// The __has_feature query is guarded by `defined(__has_feature)`. This way the
// file still preprocesses on compilers that lack the builtin (e.g. older GCC),
// without depending on some included header to provide a fallback definition.
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#define MAYBE_JITCallback DISABLED_JITCallback
#endif
#endif
#ifndef MAYBE_JITCallback
#define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
#endif
| 261 | TEST(NativeMemRefJit, MAYBE_JITCallback) { |
| 262 | constexpr int k = 2; |
| 263 | constexpr int m = 2; |
| 264 | int64_t shape[] = {k, m}; |
| 265 | int64_t shapeAlloc[] = {k + 1, m + 1}; |
| 266 | OwningMemRef<float, 2> a(shape, shapeAlloc); |
| 267 | int count = 1; |
| 268 | for (float &elt : *a) |
| 269 | elt = count++; |
| 270 | |
| 271 | #ifdef __s390__ |
| 272 | std::string moduleStr = R"mlir( |
| 273 | func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } |
| 274 | func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } { |
| 275 | %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> |
| 276 | call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () |
| 277 | return |
| 278 | } |
| 279 | )mlir" ; |
| 280 | #else |
| 281 | std::string moduleStr = R"mlir( |
| 282 | func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } |
| 283 | func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { |
| 284 | %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> |
| 285 | call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () |
| 286 | return |
| 287 | } |
| 288 | )mlir" ; |
| 289 | #endif |
| 290 | |
| 291 | DialectRegistry registry; |
| 292 | registerAllDialects(registry); |
| 293 | registerBuiltinDialectTranslation(registry); |
| 294 | registerLLVMDialectTranslation(registry); |
| 295 | MLIRContext context(registry); |
| 296 | auto module = parseSourceString<ModuleOp>(sourceStr: moduleStr, config: &context); |
| 297 | ASSERT_TRUE(!!module); |
| 298 | ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); |
| 299 | auto jitOrError = ExecutionEngine::create(op: *module); |
| 300 | ASSERT_TRUE(!!jitOrError); |
| 301 | auto jit = std::move(jitOrError.get()); |
| 302 | // Define any extra symbols so they're available at runtime. |
| 303 | jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { |
| 304 | llvm::orc::SymbolMap symbolMap; |
| 305 | symbolMap[interner("_mlir_ciface_callback" )] = { |
| 306 | llvm::orc::ExecutorAddr::fromPtr(Ptr: memrefMultiply), |
| 307 | llvm::JITSymbolFlags::Exported}; |
| 308 | return symbolMap; |
| 309 | }); |
| 310 | |
| 311 | int32_t coefficient = 3.; |
| 312 | llvm::Error error = jit->invoke("caller_for_callback" , &*a, coefficient); |
| 313 | ASSERT_TRUE(!error); |
| 314 | count = 1; |
| 315 | for (float elt : *a) |
| 316 | ASSERT_EQ(elt, coefficient * count++); |
| 317 | } |
| 318 | |
| 319 | #endif // _WIN32 |
| 320 | |