//===- jit-runner.cpp - MLIR CPU Execution Driver Library ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
//===----------------------------------------------------------------------===//
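//
// A minimal sketch of how a tool might drive this library. The registration
// calls shown are illustrative; a real tool registers whatever dialects and
// translations its input requires.
//
//   int main(int argc, char **argv) {
//     llvm::InitializeNativeTarget();
//     llvm::InitializeNativeTargetAsmPrinter();
//
//     mlir::DialectRegistry registry;
//     mlir::registerBuiltinDialectTranslation(registry);
//     mlir::registerLLVMDialectTranslation(registry);
//
//     return mlir::JitRunnerMain(argc, argv, registry);
//   }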

#include "mlir/ExecutionEngine/JitRunner.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Tools/ParseUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdint>
#include <numeric>
#include <optional>
#include <utility>

#define DEBUG_TYPE "jit-runner"

using namespace mlir;
using llvm::Error;

namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
struct Options {
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  llvm::cl::list<std::string> mAttrs{
      "mattr", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
      llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};

  llvm::cl::opt<std::string> mArch{
      "march",
      llvm::cl::desc("Architecture to generate code for (see --version)")};

  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump the JIT-compiled object to the file specified with "
                     "-object-filename (<input file>.o by default).")};

  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Name of the file to dump the JIT-compiled object to "
                     "(<input file>.o by default)")};

  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
                                      llvm::cl::desc("Report host JIT support"),
                                      llvm::cl::Hidden};

  llvm::cl::opt<bool> noImplicitModule{
      "no-implicit-module",
      llvm::cl::desc(
          "Disable implicit addition of a top-level module op during parsing"),
      llvm::cl::init(false)};
};
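
// For reference, a typical invocation exercising the flags above looks like
// the following (tool, input, and library names are illustrative):
//
//   mlir-cpu-runner input.mlir -e main -entry-point-result=f32 -O2 \
//       -shared-libs=libmlir_runner_utils.so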

struct CompileAndExecuteConfig {
  /// LLVM module transformer that is passed to ExecutionEngine.
  std::function<llvm::Error(llvm::Module *)> transformer;

  /// A custom function that is passed to ExecutionEngine. It processes the
  /// MLIR module and creates the LLVM IR module.
  llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *,
                                                   llvm::LLVMContext &)>
      llvmModuleBuilder;

  /// A custom function that is passed to ExecutionEngine to register symbols
  /// at runtime.
  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
      runtimeSymbolMap;
};
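
// As an illustration of the `runtimeSymbolMap` hook above, a runner can expose
// a host function to the JIT-compiled code. A minimal sketch, assuming a host
// function `hostPrintNewline` and the symbol name referenced by the JIT'd code
// (both hypothetical):
//
//   static llvm::orc::SymbolMap
//   exampleSymbolMap(llvm::orc::MangleAndInterner interner) {
//     llvm::orc::SymbolMap map;
//     map[interner("hostPrintNewline")] = {
//         llvm::orc::ExecutorAddr::fromPtr(&hostPrintNewline),
//         llvm::JITSymbolFlags::Exported};
//     return map;
//   }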

} // namespace

static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
                                               bool insertImplicitModule,
                                               MLIRContext *context) {
  // Set up the input file.
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  auto sourceMgr = std::make_shared<llvm::SourceMgr>();
  sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
  OwningOpRef<Operation *> module =
      parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
  if (!module)
    return nullptr;
  if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
    llvm::errs() << "Error: top-level op must be a symbol table.\n";
    return nullptr;
  }
  return module;
}

static inline Error makeStringError(const Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}

static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
  std::optional<unsigned> optLevel;
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  // Determine if there is an optimization flag present.
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optLevel = j;
      break;
    }
  }
  return optLevel;
}

// JIT-compile the given module and run "entryPoint" with "args" as arguments.
static Error
compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
                  CompileAndExecuteConfig config, void **args,
                  std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
  std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
  if (auto clOptLevel = getCommandLineOptLevel(options))
    jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);

  SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
                                       options.clSharedLibs.end());

  mlir::ExecutionEngineOptions engineOptions;
  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
  if (config.transformer)
    engineOptions.transformer = config.transformer;
  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
  engineOptions.sharedLibPaths = sharedLibs;
  engineOptions.enableObjectDump = true;
  auto expectedEngine =
      mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
  if (!expectedEngine)
    return expectedEngine.takeError();

  auto engine = std::move(*expectedEngine);

  auto expectedFPtr = engine->lookupPacked(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();

  if (options.dumpObjectFile)
    engine->dumpToObjectFile(options.objectFilename.empty()
                                 ? options.inputFilename + ".o"
                                 : options.objectFilename);

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(args);

  return Error::success();
}
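
// Note on the calling convention: `lookupPacked` returns the packed wrapper
// that ExecutionEngine generates for the entry point. That wrapper takes a
// single `void **` whose entries point at the individual arguments and
// results, which is why the helpers below pack a pointer to the result slot
// into an array of `void *` before invoking it. For a single i32 result, the
// packing amounts to roughly:
//
//   int32_t result;
//   void *args[] = {&result};
//   (*fptr)(args);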

static Error compileAndExecuteVoidFunction(
    Options &options, Operation *module, StringRef entryPoint,
    CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(module, entryPoint));
  if (!mainFunction || mainFunction.empty())
    return makeStringError("entry point not found");

  auto resultType = dyn_cast<LLVM::LLVMVoidType>(
      mainFunction.getFunctionType().getReturnType());
  if (!resultType)
    return makeStringError("expected void function");

  void *empty = nullptr;
  return compileAndExecute(options, module, entryPoint, std::move(config),
                           &empty, std::move(tm));
}

template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
  auto resultType = dyn_cast<IntegerType>(
      cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
          .getReturnType());
  if (!resultType || resultType.getWidth() != 32)
    return makeStringError("only single i32 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
  auto resultType = dyn_cast<IntegerType>(
      cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
          .getReturnType());
  if (!resultType || resultType.getWidth() != 64)
    return makeStringError("only single i64 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
  if (!isa<Float32Type>(
          cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
              .getReturnType()))
    return makeStringError("only single f32 function result supported");
  return Error::success();
}
template <typename Type>
Error compileAndExecuteSingleReturnFunction(
    Options &options, Operation *module, StringRef entryPoint,
    CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(module, entryPoint));
  if (!mainFunction || mainFunction.isExternal())
    return makeStringError("entry point not found");

  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
          .getNumParams() != 0)
    return makeStringError("function inputs not supported");

  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
    return error;

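  // Pack a single pointer to the result slot; the packed entry-point wrapper
  // expects a `void **` with one entry per argument/result.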
  Type res;
  struct {
    void *data;
  } data;
  data.data = &res;
  if (auto error =
          compileAndExecute(options, module, entryPoint, std::move(config),
                            (void **)&data, std::move(tm)))
    return error;

  // Intentionally print the result so that tests can check it.
  llvm::outs() << res << '\n';

  return Error::success();
}
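
// For illustration, an input accepted by the i32 path above (run with
// `-e entry -entry-point-result=i32`) could look like:
//
//   llvm.func @entry() -> i32 {
//     %0 = llvm.mlir.constant(42 : i32) : i32
//     llvm.return %0 : i32
//   }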

/// Entry point for all CPU runners. Expects the common argc/argv arguments for
/// standard C++ main functions.
int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
                        JitRunnerConfig config) {
  llvm::ExitOnError exitOnErr;

  // Create the options struct containing the command line options for the
  // runner. This must come before the command line options are parsed.
  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

  if (options.hostSupportsJit) {
    auto j = llvm::orc::LLJITBuilder().create();
    if (j)
      llvm::outs() << "true\n";
    else {
      llvm::outs() << "false\n";
      exitOnErr(j.takeError());
    }
    return 0;
  }

  std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  MLIRContext context(registry);

  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
                          &context);
  if (!m) {
    llvm::errs() << "could not parse the input IR\n";
    return 1;
  }

  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
  if (config.mlirTransformer)
    if (failed(config.mlirTransformer(m.get(), runnerOptions)))
      return EXIT_FAILURE;

  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
    return EXIT_FAILURE;
  }

  // Configure the TargetMachine builder based on the command line options.
  llvm::SubtargetFeatures features;
  if (!options.mAttrs.empty()) {
    for (StringRef attr : options.mAttrs)
      features.AddFeature(attr);
    tmBuilderOrError->addFeatures(features.getFeatures());
  }

  if (!options.mArch.empty()) {
    tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
  }

  // Build the TargetMachine.
  auto tmOrError = tmBuilderOrError->createTargetMachine();

  if (!tmOrError) {
    llvm::errs() << "Failed to create a TargetMachine for the host\n";
    exitOnErr(tmOrError.takeError());
  }

  LLVM_DEBUG({
    llvm::dbgs() << "  JITTargetMachineBuilder is "
                 << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
                                                              "\n");
  });

  CompileAndExecuteConfig compileAndExecuteConfig;
  if (optLevel) {
    compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
        *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
  }
  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;

  // Get the function used to compile and execute the module.
  using CompileAndExecuteFnT =
      Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
                std::unique_ptr<llvm::TargetMachine> tm);
  auto compileAndExecuteFn =
      StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
          .Case("void", compileAndExecuteVoidFunction)
          .Default(nullptr);

  Error error = compileAndExecuteFn
                    ? compileAndExecuteFn(
                          options, m.get(), options.mainFuncName.getValue(),
                          compileAndExecuteConfig, std::move(tmOrError.get()))
                    : makeStringError("unsupported function type");

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });

  return exitCode;
}