1 | //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This is a library that provides a shared implementation for command line |
10 | // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM |
11 | // IR before JIT-compiling and executing the latter. |
12 | // |
13 | // The translation can be customized by providing an MLIR to MLIR |
14 | // transformation. |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/ExecutionEngine/JitRunner.h" |
18 | |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/ExecutionEngine/ExecutionEngine.h" |
21 | #include "mlir/ExecutionEngine/OptUtils.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/MLIRContext.h" |
24 | #include "mlir/Parser/Parser.h" |
25 | #include "mlir/Support/FileUtilities.h" |
26 | #include "mlir/Tools/ParseUtilities.h" |
27 | |
28 | #include "llvm/ADT/STLExtras.h" |
29 | #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" |
30 | #include "llvm/ExecutionEngine/Orc/LLJIT.h" |
31 | #include "llvm/IR/IRBuilder.h" |
32 | #include "llvm/IR/LLVMContext.h" |
33 | #include "llvm/IR/LegacyPassNameParser.h" |
34 | #include "llvm/Support/CommandLine.h" |
35 | #include "llvm/Support/Debug.h" |
36 | #include "llvm/Support/FileUtilities.h" |
37 | #include "llvm/Support/SourceMgr.h" |
38 | #include "llvm/Support/StringSaver.h" |
39 | #include "llvm/Support/ToolOutputFile.h" |
40 | #include <cstdint> |
41 | #include <numeric> |
42 | #include <optional> |
43 | #include <utility> |
44 | |
45 | #define DEBUG_TYPE "jit-runner" |
46 | |
47 | using namespace mlir; |
48 | using llvm::Error; |
49 | |
50 | namespace { |
51 | /// This options struct prevents the need for global static initializers, and |
52 | /// is only initialized if the JITRunner is invoked. |
53 | struct Options { |
54 | llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, |
55 | llvm::cl::desc("<input file>" ), |
56 | llvm::cl::init(Val: "-" )}; |
57 | llvm::cl::opt<std::string> mainFuncName{ |
58 | "e" , llvm::cl::desc("The function to be called" ), |
59 | llvm::cl::value_desc("<function name>" ), llvm::cl::init(Val: "main" )}; |
60 | llvm::cl::opt<std::string> mainFuncType{ |
61 | "entry-point-result" , |
62 | llvm::cl::desc("Textual description of the function type to be called" ), |
63 | llvm::cl::value_desc("f32 | i32 | i64 | void" ), llvm::cl::init(Val: "f32" )}; |
64 | |
65 | llvm::cl::OptionCategory optFlags{"opt-like flags" }; |
66 | |
67 | // CLI variables for -On options. |
68 | llvm::cl::opt<bool> optO0{"O0" , |
69 | llvm::cl::desc("Run opt passes and codegen at O0" ), |
70 | llvm::cl::cat(optFlags)}; |
71 | llvm::cl::opt<bool> optO1{"O1" , |
72 | llvm::cl::desc("Run opt passes and codegen at O1" ), |
73 | llvm::cl::cat(optFlags)}; |
74 | llvm::cl::opt<bool> optO2{"O2" , |
75 | llvm::cl::desc("Run opt passes and codegen at O2" ), |
76 | llvm::cl::cat(optFlags)}; |
77 | llvm::cl::opt<bool> optO3{"O3" , |
78 | llvm::cl::desc("Run opt passes and codegen at O3" ), |
79 | llvm::cl::cat(optFlags)}; |
80 | |
81 | llvm::cl::list<std::string> mAttrs{ |
82 | "mattr" , llvm::cl::MiscFlags::CommaSeparated, |
83 | llvm::cl::desc("Target specific attributes (-mattr=help for details)" ), |
84 | llvm::cl::value_desc("a1,+a2,-a3,..." ), llvm::cl::cat(optFlags)}; |
85 | |
86 | llvm::cl::opt<std::string> mArch{ |
87 | "march" , |
88 | llvm::cl::desc("Architecture to generate code for (see --version)" )}; |
89 | |
90 | llvm::cl::OptionCategory clOptionsCategory{"linking options" }; |
91 | llvm::cl::list<std::string> clSharedLibs{ |
92 | "shared-libs" , llvm::cl::desc("Libraries to link dynamically" ), |
93 | llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)}; |
94 | |
95 | /// CLI variables for debugging. |
96 | llvm::cl::opt<bool> dumpObjectFile{ |
97 | "dump-object-file" , |
98 | llvm::cl::desc("Dump JITted-compiled object to file specified with " |
99 | "-object-filename (<input file>.o by default)." )}; |
100 | |
101 | llvm::cl::opt<std::string> objectFilename{ |
102 | "object-filename" , |
103 | llvm::cl::desc("Dump JITted-compiled object to file <input file>.o" )}; |
104 | |
105 | llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit" , |
106 | llvm::cl::desc("Report host JIT support" ), |
107 | llvm::cl::Hidden}; |
108 | |
109 | llvm::cl::opt<bool> noImplicitModule{ |
110 | "no-implicit-module" , |
111 | llvm::cl::desc( |
112 | "Disable implicit addition of a top-level module op during parsing" ), |
113 | llvm::cl::init(Val: false)}; |
114 | }; |
115 | |
116 | struct CompileAndExecuteConfig { |
117 | /// LLVM module transformer that is passed to ExecutionEngine. |
118 | std::function<llvm::Error(llvm::Module *)> transformer; |
119 | |
120 | /// A custom function that is passed to ExecutionEngine. It processes MLIR |
121 | /// module and creates LLVM IR module. |
122 | llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *, |
123 | llvm::LLVMContext &)> |
124 | llvmModuleBuilder; |
125 | |
126 | /// A custom function that is passed to ExecutinEngine to register symbols at |
127 | /// runtime. |
128 | llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> |
129 | runtimeSymbolMap; |
130 | }; |
131 | |
132 | } // namespace |
133 | |
134 | static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename, |
135 | bool insertImplicitModule, |
136 | MLIRContext *context) { |
137 | // Set up the input file. |
138 | std::string errorMessage; |
139 | auto file = openInputFile(inputFilename, errorMessage: &errorMessage); |
140 | if (!file) { |
141 | llvm::errs() << errorMessage << "\n" ; |
142 | return nullptr; |
143 | } |
144 | |
145 | auto sourceMgr = std::make_shared<llvm::SourceMgr>(); |
146 | sourceMgr->AddNewSourceBuffer(F: std::move(file), IncludeLoc: SMLoc()); |
147 | OwningOpRef<Operation *> module = |
148 | parseSourceFileForTool(sourceMgr, config: context, insertImplicitModule); |
149 | if (!module) |
150 | return nullptr; |
151 | if (!module.get()->hasTrait<OpTrait::SymbolTable>()) { |
152 | llvm::errs() << "Error: top-level op must be a symbol table.\n" ; |
153 | return nullptr; |
154 | } |
155 | return module; |
156 | } |
157 | |
158 | static inline Error makeStringError(const Twine &message) { |
159 | return llvm::make_error<llvm::StringError>(Args: message.str(), |
160 | Args: llvm::inconvertibleErrorCode()); |
161 | } |
162 | |
163 | static std::optional<unsigned> getCommandLineOptLevel(Options &options) { |
164 | std::optional<unsigned> optLevel; |
165 | SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ |
166 | options.optO0, options.optO1, options.optO2, options.optO3}; |
167 | |
168 | // Determine if there is an optimization flag present. |
169 | for (unsigned j = 0; j < 4; ++j) { |
170 | auto &flag = optFlags[j].get(); |
171 | if (flag) { |
172 | optLevel = j; |
173 | break; |
174 | } |
175 | } |
176 | return optLevel; |
177 | } |
178 | |
179 | // JIT-compile the given module and run "entryPoint" with "args" as arguments. |
180 | static Error |
181 | compileAndExecute(Options &options, Operation *module, StringRef entryPoint, |
182 | CompileAndExecuteConfig config, void **args, |
183 | std::unique_ptr<llvm::TargetMachine> tm = nullptr) { |
184 | std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel; |
185 | if (auto clOptLevel = getCommandLineOptLevel(options)) |
186 | jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel); |
187 | |
188 | SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(), |
189 | options.clSharedLibs.end()); |
190 | |
191 | mlir::ExecutionEngineOptions engineOptions; |
192 | engineOptions.llvmModuleBuilder = config.llvmModuleBuilder; |
193 | if (config.transformer) |
194 | engineOptions.transformer = config.transformer; |
195 | engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel; |
196 | engineOptions.sharedLibPaths = sharedLibs; |
197 | engineOptions.enableObjectDump = true; |
198 | auto expectedEngine = |
199 | mlir::ExecutionEngine::create(op: module, options: engineOptions, tm: std::move(tm)); |
200 | if (!expectedEngine) |
201 | return expectedEngine.takeError(); |
202 | |
203 | auto engine = std::move(*expectedEngine); |
204 | |
205 | auto expectedFPtr = engine->lookupPacked(name: entryPoint); |
206 | if (!expectedFPtr) |
207 | return expectedFPtr.takeError(); |
208 | |
209 | if (options.dumpObjectFile) |
210 | engine->dumpToObjectFile(filename: options.objectFilename.empty() |
211 | ? options.inputFilename + ".o" |
212 | : options.objectFilename); |
213 | |
214 | void (*fptr)(void **) = *expectedFPtr; |
215 | (*fptr)(args); |
216 | |
217 | return Error::success(); |
218 | } |
219 | |
220 | static Error compileAndExecuteVoidFunction( |
221 | Options &options, Operation *module, StringRef entryPoint, |
222 | CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) { |
223 | auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>( |
224 | SymbolTable::lookupSymbolIn(module, entryPoint)); |
225 | if (!mainFunction || mainFunction.isExternal()) |
226 | return makeStringError(message: "entry point not found" ); |
227 | |
228 | if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
229 | .getNumParams() != 0) |
230 | return makeStringError( |
231 | message: "JIT can't invoke a main function expecting arguments" ); |
232 | |
233 | auto resultType = dyn_cast<LLVM::LLVMVoidType>( |
234 | mainFunction.getFunctionType().getReturnType()); |
235 | if (!resultType) |
236 | return makeStringError(message: "expected void function" ); |
237 | |
238 | void *empty = nullptr; |
239 | return compileAndExecute(options, module, entryPoint, config: std::move(config), |
240 | args: &empty, tm: std::move(tm)); |
241 | } |
242 | |
243 | template <typename Type> |
244 | Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); |
245 | template <> |
246 | Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { |
247 | auto resultType = dyn_cast<IntegerType>( |
248 | cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
249 | .getReturnType()); |
250 | if (!resultType || resultType.getWidth() != 32) |
251 | return makeStringError(message: "only single i32 function result supported" ); |
252 | return Error::success(); |
253 | } |
254 | template <> |
255 | Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { |
256 | auto resultType = dyn_cast<IntegerType>( |
257 | cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
258 | .getReturnType()); |
259 | if (!resultType || resultType.getWidth() != 64) |
260 | return makeStringError(message: "only single i64 function result supported" ); |
261 | return Error::success(); |
262 | } |
263 | template <> |
264 | Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { |
265 | if (!isa<Float32Type>( |
266 | cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
267 | .getReturnType())) |
268 | return makeStringError("only single f32 function result supported" ); |
269 | return Error::success(); |
270 | } |
271 | template <typename Type> |
272 | Error compileAndExecuteSingleReturnFunction( |
273 | Options &options, Operation *module, StringRef entryPoint, |
274 | CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) { |
275 | auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>( |
276 | SymbolTable::lookupSymbolIn(module, entryPoint)); |
277 | if (!mainFunction || mainFunction.isExternal()) |
278 | return makeStringError(message: "entry point not found" ); |
279 | |
280 | if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType()) |
281 | .getNumParams() != 0) |
282 | return makeStringError( |
283 | message: "JIT can't invoke a main function expecting arguments" ); |
284 | |
285 | if (Error error = checkCompatibleReturnType<Type>(mainFunction)) |
286 | return error; |
287 | |
288 | Type res; |
289 | struct { |
290 | void *data; |
291 | } data; |
292 | data.data = &res; |
293 | if (auto error = |
294 | compileAndExecute(options, module, entryPoint, config: std::move(config), |
295 | args: (void **)&data, tm: std::move(tm))) |
296 | return error; |
297 | |
298 | // Intentional printing of the output so we can test. |
299 | llvm::outs() << res << '\n'; |
300 | |
301 | return Error::success(); |
302 | } |
303 | |
304 | /// Entry point for all CPU runners. Expects the common argc/argv arguments for |
305 | /// standard C++ main functions. |
306 | int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, |
307 | JitRunnerConfig config) { |
308 | llvm::ExitOnError exitOnErr; |
309 | |
310 | // Create the options struct containing the command line options for the |
311 | // runner. This must come before the command line options are parsed. |
312 | Options options; |
313 | llvm::cl::ParseCommandLineOptions(argc, argv, Overview: "MLIR CPU execution driver\n" ); |
314 | |
315 | if (options.hostSupportsJit) { |
316 | auto j = llvm::orc::LLJITBuilder().create(); |
317 | if (j) |
318 | llvm::outs() << "true\n" ; |
319 | else { |
320 | llvm::outs() << "false\n" ; |
321 | exitOnErr(j.takeError()); |
322 | } |
323 | return 0; |
324 | } |
325 | |
326 | std::optional<unsigned> optLevel = getCommandLineOptLevel(options); |
327 | SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ |
328 | options.optO0, options.optO1, options.optO2, options.optO3}; |
329 | |
330 | MLIRContext context(registry); |
331 | |
332 | auto m = parseMLIRInput(inputFilename: options.inputFilename, insertImplicitModule: !options.noImplicitModule, |
333 | context: &context); |
334 | if (!m) { |
335 | llvm::errs() << "could not parse the input IR\n" ; |
336 | return 1; |
337 | } |
338 | |
339 | JitRunnerOptions runnerOptions{.mainFuncName: options.mainFuncName, .mainFuncType: options.mainFuncType}; |
340 | if (config.mlirTransformer) |
341 | if (failed(Result: config.mlirTransformer(m.get(), runnerOptions))) |
342 | return EXIT_FAILURE; |
343 | |
344 | auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); |
345 | if (!tmBuilderOrError) { |
346 | llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n" ; |
347 | return EXIT_FAILURE; |
348 | } |
349 | |
350 | // Configure TargetMachine builder based on the command line options |
351 | llvm::SubtargetFeatures features; |
352 | if (!options.mAttrs.empty()) { |
353 | for (StringRef attr : options.mAttrs) |
354 | features.AddFeature(String: attr); |
355 | tmBuilderOrError->addFeatures(FeatureVec: features.getFeatures()); |
356 | } |
357 | |
358 | if (!options.mArch.empty()) { |
359 | tmBuilderOrError->getTargetTriple().setArchName(options.mArch); |
360 | } |
361 | |
362 | // Build TargetMachine |
363 | auto tmOrError = tmBuilderOrError->createTargetMachine(); |
364 | |
365 | if (!tmOrError) { |
366 | llvm::errs() << "Failed to create a TargetMachine for the host\n" ; |
367 | exitOnErr(tmOrError.takeError()); |
368 | } |
369 | |
370 | LLVM_DEBUG({ |
371 | llvm::dbgs() << " JITTargetMachineBuilder is " |
372 | << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError, |
373 | "\n" ); |
374 | }); |
375 | |
376 | CompileAndExecuteConfig compileAndExecuteConfig; |
377 | if (optLevel) { |
378 | compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer( |
379 | optLevel: *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); |
380 | } |
381 | compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; |
382 | compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; |
383 | |
384 | // Get the function used to compile and execute the module. |
385 | using CompileAndExecuteFnT = |
386 | Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig, |
387 | std::unique_ptr<llvm::TargetMachine> tm); |
388 | auto compileAndExecuteFn = |
389 | StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) |
390 | .Case(S: "i32" , Value: compileAndExecuteSingleReturnFunction<int32_t>) |
391 | .Case(S: "i64" , Value: compileAndExecuteSingleReturnFunction<int64_t>) |
392 | .Case(S: "f32" , Value: compileAndExecuteSingleReturnFunction<float>) |
393 | .Case(S: "void" , Value: compileAndExecuteVoidFunction) |
394 | .Default(Value: nullptr); |
395 | |
396 | Error error = compileAndExecuteFn |
397 | ? compileAndExecuteFn( |
398 | options, m.get(), options.mainFuncName.getValue(), |
399 | compileAndExecuteConfig, std::move(tmOrError.get())) |
400 | : makeStringError(message: "unsupported function type" ); |
401 | |
402 | int exitCode = EXIT_SUCCESS; |
403 | llvm::handleAllErrors(E: std::move(error), |
404 | Handlers: [&exitCode](const llvm::ErrorInfoBase &info) { |
405 | llvm::errs() << "Error: " ; |
406 | info.log(OS&: llvm::errs()); |
407 | llvm::errs() << '\n'; |
408 | exitCode = EXIT_FAILURE; |
409 | }); |
410 | |
411 | return exitCode; |
412 | } |
413 | |