1 | //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This is a library that provides a shared implementation for command line |
10 | // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM |
11 | // IR before JIT-compiling and executing the latter. |
12 | // |
13 | // The translation can be customized by providing an MLIR to MLIR |
14 | // transformation. |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/ExecutionEngine/JitRunner.h" |
18 | |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/ExecutionEngine/ExecutionEngine.h" |
21 | #include "mlir/ExecutionEngine/OptUtils.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/MLIRContext.h" |
24 | #include "mlir/Parser/Parser.h" |
25 | #include "mlir/Support/FileUtilities.h" |
26 | #include "mlir/Tools/ParseUtilities.h" |
27 | |
28 | #include "llvm/ADT/STLExtras.h" |
29 | #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" |
30 | #include "llvm/ExecutionEngine/Orc/LLJIT.h" |
31 | #include "llvm/IR/IRBuilder.h" |
32 | #include "llvm/IR/LLVMContext.h" |
33 | #include "llvm/IR/LegacyPassNameParser.h" |
34 | #include "llvm/Support/CommandLine.h" |
35 | #include "llvm/Support/Debug.h" |
36 | #include "llvm/Support/FileUtilities.h" |
37 | #include "llvm/Support/SourceMgr.h" |
38 | #include "llvm/Support/StringSaver.h" |
39 | #include "llvm/Support/ToolOutputFile.h" |
40 | #include <cstdint> |
41 | #include <numeric> |
42 | #include <optional> |
43 | #include <utility> |
44 | |
45 | #define DEBUG_TYPE "jit-runner" |
46 | |
47 | using namespace mlir; |
48 | using llvm::Error; |
49 | |
50 | namespace { |
51 | /// This options struct prevents the need for global static initializers, and |
52 | /// is only initialized if the JITRunner is invoked. |
53 | struct Options { |
54 | llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, |
55 | llvm::cl::desc("<input file>" ), |
56 | llvm::cl::init(Val: "-" )}; |
57 | llvm::cl::opt<std::string> mainFuncName{ |
58 | "e" , llvm::cl::desc("The function to be called" ), |
59 | llvm::cl::value_desc("<function name>" ), llvm::cl::init(Val: "main" )}; |
60 | llvm::cl::opt<std::string> mainFuncType{ |
61 | "entry-point-result" , |
62 | llvm::cl::desc("Textual description of the function type to be called" ), |
63 | llvm::cl::value_desc("f32 | i32 | i64 | void" ), llvm::cl::init(Val: "f32" )}; |
64 | |
65 | llvm::cl::OptionCategory optFlags{"opt-like flags" }; |
66 | |
67 | // CLI variables for -On options. |
68 | llvm::cl::opt<bool> optO0{"O0" , |
69 | llvm::cl::desc("Run opt passes and codegen at O0" ), |
70 | llvm::cl::cat(optFlags)}; |
71 | llvm::cl::opt<bool> optO1{"O1" , |
72 | llvm::cl::desc("Run opt passes and codegen at O1" ), |
73 | llvm::cl::cat(optFlags)}; |
74 | llvm::cl::opt<bool> optO2{"O2" , |
75 | llvm::cl::desc("Run opt passes and codegen at O2" ), |
76 | llvm::cl::cat(optFlags)}; |
77 | llvm::cl::opt<bool> optO3{"O3" , |
78 | llvm::cl::desc("Run opt passes and codegen at O3" ), |
79 | llvm::cl::cat(optFlags)}; |
80 | |
81 | llvm::cl::list<std::string> mAttrs{ |
82 | "mattr" , llvm::cl::MiscFlags::CommaSeparated, |
83 | llvm::cl::desc("Target specific attributes (-mattr=help for details)" ), |
84 | llvm::cl::value_desc("a1,+a2,-a3,..." ), llvm::cl::cat(optFlags)}; |
85 | |
86 | llvm::cl::opt<std::string> mArch{ |
87 | "march" , |
88 | llvm::cl::desc("Architecture to generate code for (see --version)" )}; |
89 | |
90 | llvm::cl::OptionCategory clOptionsCategory{"linking options" }; |
91 | llvm::cl::list<std::string> clSharedLibs{ |
92 | "shared-libs" , llvm::cl::desc("Libraries to link dynamically" ), |
93 | llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)}; |
94 | |
95 | /// CLI variables for debugging. |
96 | llvm::cl::opt<bool> dumpObjectFile{ |
97 | "dump-object-file" , |
98 | llvm::cl::desc("Dump JITted-compiled object to file specified with " |
99 | "-object-filename (<input file>.o by default)." )}; |
100 | |
101 | llvm::cl::opt<std::string> objectFilename{ |
102 | "object-filename" , |
103 | llvm::cl::desc("Dump JITted-compiled object to file <input file>.o" )}; |
104 | |
105 | llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit" , |
106 | llvm::cl::desc("Report host JIT support" ), |
107 | llvm::cl::Hidden}; |
108 | |
109 | llvm::cl::opt<bool> noImplicitModule{ |
110 | "no-implicit-module" , |
111 | llvm::cl::desc( |
112 | "Disable implicit addition of a top-level module op during parsing" ), |
113 | llvm::cl::init(Val: false)}; |
114 | }; |
115 | |
116 | struct CompileAndExecuteConfig { |
117 | /// LLVM module transformer that is passed to ExecutionEngine. |
118 | std::function<llvm::Error(llvm::Module *)> transformer; |
119 | |
120 | /// A custom function that is passed to ExecutionEngine. It processes MLIR |
121 | /// module and creates LLVM IR module. |
122 | llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *, |
123 | llvm::LLVMContext &)> |
124 | llvmModuleBuilder; |
125 | |
126 | /// A custom function that is passed to ExecutinEngine to register symbols at |
127 | /// runtime. |
128 | llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> |
129 | runtimeSymbolMap; |
130 | }; |
131 | |
132 | } // namespace |
133 | |
134 | static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename, |
135 | bool insertImplicitModule, |
136 | MLIRContext *context) { |
137 | // Set up the input file. |
138 | std::string errorMessage; |
139 | auto file = openInputFile(inputFilename, errorMessage: &errorMessage); |
140 | if (!file) { |
141 | llvm::errs() << errorMessage << "\n" ; |
142 | return nullptr; |
143 | } |
144 | |
145 | auto sourceMgr = std::make_shared<llvm::SourceMgr>(); |
146 | sourceMgr->AddNewSourceBuffer(F: std::move(file), IncludeLoc: SMLoc()); |
147 | OwningOpRef<Operation *> module = |
148 | parseSourceFileForTool(sourceMgr, config: context, insertImplicitModule); |
149 | if (!module) |
150 | return nullptr; |
151 | if (!module.get()->hasTrait<OpTrait::SymbolTable>()) { |
152 | llvm::errs() << "Error: top-level op must be a symbol table.\n" ; |
153 | return nullptr; |
154 | } |
155 | return module; |
156 | } |
157 | |
158 | static inline Error makeStringError(const Twine &message) { |
159 | return llvm::make_error<llvm::StringError>(Args: message.str(), |
160 | Args: llvm::inconvertibleErrorCode()); |
161 | } |
162 | |
163 | static std::optional<unsigned> getCommandLineOptLevel(Options &options) { |
164 | std::optional<unsigned> optLevel; |
165 | SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ |
166 | options.optO0, options.optO1, options.optO2, options.optO3}; |
167 | |
168 | // Determine if there is an optimization flag present. |
169 | for (unsigned j = 0; j < 4; ++j) { |
170 | auto &flag = optFlags[j].get(); |
171 | if (flag) { |
172 | optLevel = j; |
173 | break; |
174 | } |
175 | } |
176 | return optLevel; |
177 | } |
178 | |
179 | // JIT-compile the given module and run "entryPoint" with "args" as arguments. |
180 | static Error |
181 | compileAndExecute(Options &options, Operation *module, StringRef entryPoint, |
182 | CompileAndExecuteConfig config, void **args, |
183 | std::unique_ptr<llvm::TargetMachine> tm = nullptr) { |
184 | std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel; |
185 | if (auto clOptLevel = getCommandLineOptLevel(options)) |
186 | jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel); |
187 | |
188 | SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(), |
189 | options.clSharedLibs.end()); |
190 | |
191 | mlir::ExecutionEngineOptions engineOptions; |
192 | engineOptions.llvmModuleBuilder = config.llvmModuleBuilder; |
193 | if (config.transformer) |
194 | engineOptions.transformer = config.transformer; |
195 | engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel; |
196 | engineOptions.sharedLibPaths = sharedLibs; |
197 | engineOptions.enableObjectDump = true; |
198 | auto expectedEngine = |
199 | mlir::ExecutionEngine::create(op: module, options: engineOptions, tm: std::move(tm)); |
200 | if (!expectedEngine) |
201 | return expectedEngine.takeError(); |
202 | |
203 | auto engine = std::move(*expectedEngine); |
204 | |
205 | auto expectedFPtr = engine->lookupPacked(name: entryPoint); |
206 | if (!expectedFPtr) |
207 | return expectedFPtr.takeError(); |
208 | |
209 | if (options.dumpObjectFile) |
210 | engine->dumpToObjectFile(filename: options.objectFilename.empty() |
211 | ? options.inputFilename + ".o" |
212 | : options.objectFilename); |
213 | |
214 | void (*fptr)(void **) = *expectedFPtr; |
215 | (*fptr)(args); |
216 | |
217 | return Error::success(); |
218 | } |
219 | |
220 | static Error compileAndExecuteVoidFunction( |
221 | Options &options, Operation *module, StringRef entryPoint, |
222 | CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) { |
223 | auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>( |
224 | SymbolTable::lookupSymbolIn(module, entryPoint)); |
225 | if (!mainFunction || mainFunction.empty()) |
226 | return makeStringError(message: "entry point not found" ); |
227 | |
228 | auto resultType = dyn_cast<LLVM::LLVMVoidType>( |
229 | mainFunction.getFunctionType().getReturnType()); |
230 | if (!resultType) |
231 | return makeStringError(message: "expected void function" ); |
232 | |
233 | void *empty = nullptr; |
234 | return compileAndExecute(options, module, entryPoint, config: std::move(config), |
235 | args: &empty, tm: std::move(tm)); |
236 | } |
237 | |
238 | template <typename Type> |
239 | Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); |
240 | template <> |
241 | Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { |
242 | auto resultType = dyn_cast<IntegerType>( |
243 | cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
244 | .getReturnType()); |
245 | if (!resultType || resultType.getWidth() != 32) |
246 | return makeStringError(message: "only single i32 function result supported" ); |
247 | return Error::success(); |
248 | } |
249 | template <> |
250 | Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { |
251 | auto resultType = dyn_cast<IntegerType>( |
252 | cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
253 | .getReturnType()); |
254 | if (!resultType || resultType.getWidth() != 64) |
255 | return makeStringError(message: "only single i64 function result supported" ); |
256 | return Error::success(); |
257 | } |
258 | template <> |
259 | Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { |
260 | if (!isa<Float32Type>( |
261 | cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
262 | .getReturnType())) |
263 | return makeStringError("only single f32 function result supported" ); |
264 | return Error::success(); |
265 | } |
266 | template <typename Type> |
267 | Error compileAndExecuteSingleReturnFunction( |
268 | Options &options, Operation *module, StringRef entryPoint, |
269 | CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) { |
270 | auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>( |
271 | SymbolTable::lookupSymbolIn(module, entryPoint)); |
272 | if (!mainFunction || mainFunction.isExternal()) |
273 | return makeStringError(message: "entry point not found" ); |
274 | |
275 | if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
276 | .getNumParams() != 0) |
277 | return makeStringError(message: "function inputs not supported" ); |
278 | |
279 | if (Error error = checkCompatibleReturnType<Type>(mainFunction)) |
280 | return error; |
281 | |
282 | Type res; |
283 | struct { |
284 | void *data; |
285 | } data; |
286 | data.data = &res; |
287 | if (auto error = |
288 | compileAndExecute(options, module, entryPoint, config: std::move(config), |
289 | args: (void **)&data, tm: std::move(tm))) |
290 | return error; |
291 | |
292 | // Intentional printing of the output so we can test. |
293 | llvm::outs() << res << '\n'; |
294 | |
295 | return Error::success(); |
296 | } |
297 | |
298 | /// Entry point for all CPU runners. Expects the common argc/argv arguments for |
299 | /// standard C++ main functions. |
300 | int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, |
301 | JitRunnerConfig config) { |
302 | llvm::ExitOnError exitOnErr; |
303 | |
304 | // Create the options struct containing the command line options for the |
305 | // runner. This must come before the command line options are parsed. |
306 | Options options; |
307 | llvm::cl::ParseCommandLineOptions(argc, argv, Overview: "MLIR CPU execution driver\n" ); |
308 | |
309 | if (options.hostSupportsJit) { |
310 | auto j = llvm::orc::LLJITBuilder().create(); |
311 | if (j) |
312 | llvm::outs() << "true\n" ; |
313 | else { |
314 | llvm::outs() << "false\n" ; |
315 | exitOnErr(j.takeError()); |
316 | } |
317 | return 0; |
318 | } |
319 | |
320 | std::optional<unsigned> optLevel = getCommandLineOptLevel(options); |
321 | SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ |
322 | options.optO0, options.optO1, options.optO2, options.optO3}; |
323 | |
324 | MLIRContext context(registry); |
325 | |
326 | auto m = parseMLIRInput(inputFilename: options.inputFilename, insertImplicitModule: !options.noImplicitModule, |
327 | context: &context); |
328 | if (!m) { |
329 | llvm::errs() << "could not parse the input IR\n" ; |
330 | return 1; |
331 | } |
332 | |
333 | JitRunnerOptions runnerOptions{.mainFuncName: options.mainFuncName, .mainFuncType: options.mainFuncType}; |
334 | if (config.mlirTransformer) |
335 | if (failed(result: config.mlirTransformer(m.get(), runnerOptions))) |
336 | return EXIT_FAILURE; |
337 | |
338 | auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); |
339 | if (!tmBuilderOrError) { |
340 | llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n" ; |
341 | return EXIT_FAILURE; |
342 | } |
343 | |
344 | // Configure TargetMachine builder based on the command line options |
345 | llvm::SubtargetFeatures features; |
346 | if (!options.mAttrs.empty()) { |
347 | for (StringRef attr : options.mAttrs) |
348 | features.AddFeature(String: attr); |
349 | tmBuilderOrError->addFeatures(FeatureVec: features.getFeatures()); |
350 | } |
351 | |
352 | if (!options.mArch.empty()) { |
353 | tmBuilderOrError->getTargetTriple().setArchName(options.mArch); |
354 | } |
355 | |
356 | // Build TargetMachine |
357 | auto tmOrError = tmBuilderOrError->createTargetMachine(); |
358 | |
359 | if (!tmOrError) { |
360 | llvm::errs() << "Failed to create a TargetMachine for the host\n" ; |
361 | exitOnErr(tmOrError.takeError()); |
362 | } |
363 | |
364 | LLVM_DEBUG({ |
365 | llvm::dbgs() << " JITTargetMachineBuilder is " |
366 | << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError, |
367 | "\n" ); |
368 | }); |
369 | |
370 | CompileAndExecuteConfig compileAndExecuteConfig; |
371 | if (optLevel) { |
372 | compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer( |
373 | optLevel: *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); |
374 | } |
375 | compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; |
376 | compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; |
377 | |
378 | // Get the function used to compile and execute the module. |
379 | using CompileAndExecuteFnT = |
380 | Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig, |
381 | std::unique_ptr<llvm::TargetMachine> tm); |
382 | auto compileAndExecuteFn = |
383 | StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) |
384 | .Case(S: "i32" , Value: compileAndExecuteSingleReturnFunction<int32_t>) |
385 | .Case(S: "i64" , Value: compileAndExecuteSingleReturnFunction<int64_t>) |
386 | .Case(S: "f32" , Value: compileAndExecuteSingleReturnFunction<float>) |
387 | .Case(S: "void" , Value: compileAndExecuteVoidFunction) |
388 | .Default(Value: nullptr); |
389 | |
390 | Error error = compileAndExecuteFn |
391 | ? compileAndExecuteFn( |
392 | options, m.get(), options.mainFuncName.getValue(), |
393 | compileAndExecuteConfig, std::move(tmOrError.get())) |
394 | : makeStringError(message: "unsupported function type" ); |
395 | |
396 | int exitCode = EXIT_SUCCESS; |
397 | llvm::handleAllErrors(E: std::move(error), |
398 | Handlers: [&exitCode](const llvm::ErrorInfoBase &info) { |
399 | llvm::errs() << "Error: " ; |
400 | info.log(OS&: llvm::errs()); |
401 | llvm::errs() << '\n'; |
402 | exitCode = EXIT_FAILURE; |
403 | }); |
404 | |
405 | return exitCode; |
406 | } |
407 | |