1//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the execution engine for MLIR modules based on LLVM Orc
10// JIT engine.
11//
12//===----------------------------------------------------------------------===//
13#include "mlir/ExecutionEngine/ExecutionEngine.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/IR/BuiltinOps.h"
16#include "mlir/Support/FileUtilities.h"
17#include "mlir/Target/LLVMIR/Export.h"
18
19#include "llvm/ExecutionEngine/JITEventListener.h"
20#include "llvm/ExecutionEngine/ObjectCache.h"
21#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
22#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
23#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
26#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/MC/TargetRegistry.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/Error.h"
31#include "llvm/Support/ToolOutputFile.h"
32#include "llvm/TargetParser/Host.h"
33#include "llvm/TargetParser/SubtargetFeature.h"
34
35#define DEBUG_TYPE "execution-engine"
36
37using namespace mlir;
38using llvm::dbgs;
39using llvm::Error;
40using llvm::errs;
41using llvm::Expected;
42using llvm::LLVMContext;
43using llvm::MemoryBuffer;
44using llvm::MemoryBufferRef;
45using llvm::Module;
46using llvm::SectionMemoryManager;
47using llvm::StringError;
48using llvm::Triple;
49using llvm::orc::DynamicLibrarySearchGenerator;
50using llvm::orc::ExecutionSession;
51using llvm::orc::IRCompileLayer;
52using llvm::orc::JITTargetMachineBuilder;
53using llvm::orc::MangleAndInterner;
54using llvm::orc::RTDyldObjectLinkingLayer;
55using llvm::orc::SymbolMap;
56using llvm::orc::ThreadSafeModule;
57using llvm::orc::TMOwningSimpleCompiler;
58
59/// Wrap a string into an llvm::StringError.
60static Error makeStringError(const Twine &message) {
61 return llvm::make_error<StringError>(Args: message.str(),
62 Args: llvm::inconvertibleErrorCode());
63}
64
65void SimpleObjectCache::notifyObjectCompiled(const Module *m,
66 MemoryBufferRef objBuffer) {
67 cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
68 InputData: objBuffer.getBuffer(), BufferName: objBuffer.getBufferIdentifier());
69}
70
71std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
72 auto i = cachedObjects.find(Key: m->getModuleIdentifier());
73 if (i == cachedObjects.end()) {
74 LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
75 << " in cache. Compiling.\n");
76 return nullptr;
77 }
78 LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
79 << " loaded from cache.\n");
80 return MemoryBuffer::getMemBuffer(Ref: i->second->getMemBufferRef());
81}
82
83void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
84 // Set up the output file.
85 std::string errorMessage;
86 auto file = openOutputFile(outputFilename, errorMessage: &errorMessage);
87 if (!file) {
88 llvm::errs() << errorMessage << "\n";
89 return;
90 }
91
92 // Dump the object generated for a single module to the output file.
93 assert(cachedObjects.size() == 1 && "Expected only one object entry.");
94 auto &cachedObject = cachedObjects.begin()->second;
95 file->os() << cachedObject->getBuffer();
96 file->keep();
97}
98
99bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }
100
101void ExecutionEngine::dumpToObjectFile(StringRef filename) {
102 if (cache == nullptr) {
103 llvm::errs() << "cannot dump ExecutionEngine object code to file: "
104 "object cache is disabled\n";
105 return;
106 }
107 // Compilation is lazy and it doesn't populate object cache unless requested.
108 // In case object dump is requested before cache is populated, we need to
109 // force compilation manually.
110 if (cache->isEmpty()) {
111 for (std::string &functionName : functionNames) {
112 auto result = lookupPacked(name: functionName);
113 if (!result) {
114 llvm::errs() << "Could not compile " << functionName << ":\n "
115 << result.takeError() << "\n";
116 return;
117 }
118 }
119 }
120 cache->dumpToObjectFile(outputFilename: filename);
121}
122
123void ExecutionEngine::registerSymbols(
124 llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
125 auto &mainJitDylib = jit->getMainJITDylib();
126 cantFail(Err: mainJitDylib.define(
127 MU: absoluteSymbols(Symbols: symbolMap(llvm::orc::MangleAndInterner(
128 mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
129}
130
131void ExecutionEngine::setupTargetTripleAndDataLayout(Module *llvmModule,
132 llvm::TargetMachine *tm) {
133 llvmModule->setDataLayout(tm->createDataLayout());
134 llvmModule->setTargetTriple(tm->getTargetTriple().getTriple());
135}
136
137static std::string makePackedFunctionName(StringRef name) {
138 return "_mlir_" + name.str();
139}
140
141// For each function in the LLVM module, define an interface function that wraps
142// all the arguments of the original function and all its results into an i8**
143// pointer to provide a unified invocation interface.
144static void packFunctionArguments(Module *module) {
145 auto &ctx = module->getContext();
146 llvm::IRBuilder<> builder(ctx);
147 DenseSet<llvm::Function *> interfaceFunctions;
148 for (auto &func : module->getFunctionList()) {
149 if (func.isDeclaration()) {
150 continue;
151 }
152 if (interfaceFunctions.count(V: &func)) {
153 continue;
154 }
155
156 // Given a function `foo(<...>)`, define the interface function
157 // `mlir_foo(i8**)`.
158 auto *newType =
159 llvm::FunctionType::get(Result: builder.getVoidTy(), Params: builder.getPtrTy(),
160 /*isVarArg=*/false);
161 auto newName = makePackedFunctionName(name: func.getName());
162 auto funcCst = module->getOrInsertFunction(Name: newName, T: newType);
163 llvm::Function *interfaceFunc = cast<llvm::Function>(Val: funcCst.getCallee());
164 interfaceFunctions.insert(V: interfaceFunc);
165
166 // Extract the arguments from the type-erased argument list and cast them to
167 // the proper types.
168 auto *bb = llvm::BasicBlock::Create(Context&: ctx);
169 bb->insertInto(Parent: interfaceFunc);
170 builder.SetInsertPoint(bb);
171 llvm::Value *argList = interfaceFunc->arg_begin();
172 SmallVector<llvm::Value *, 8> args;
173 args.reserve(N: llvm::size(Range: func.args()));
174 for (auto [index, arg] : llvm::enumerate(First: func.args())) {
175 llvm::Value *argIndex = llvm::Constant::getIntegerValue(
176 Ty: builder.getInt64Ty(), V: APInt(64, index));
177 llvm::Value *argPtrPtr =
178 builder.CreateGEP(Ty: builder.getPtrTy(), Ptr: argList, IdxList: argIndex);
179 llvm::Value *argPtr = builder.CreateLoad(Ty: builder.getPtrTy(), Ptr: argPtrPtr);
180 llvm::Type *argTy = arg.getType();
181 llvm::Value *load = builder.CreateLoad(Ty: argTy, Ptr: argPtr);
182 args.push_back(Elt: load);
183 }
184
185 // Call the implementation function with the extracted arguments.
186 llvm::Value *result = builder.CreateCall(Callee: &func, Args: args);
187
188 // Assuming the result is one value, potentially of type `void`.
189 if (!result->getType()->isVoidTy()) {
190 llvm::Value *retIndex = llvm::Constant::getIntegerValue(
191 Ty: builder.getInt64Ty(), V: APInt(64, llvm::size(Range: func.args())));
192 llvm::Value *retPtrPtr =
193 builder.CreateGEP(Ty: builder.getPtrTy(), Ptr: argList, IdxList: retIndex);
194 llvm::Value *retPtr = builder.CreateLoad(Ty: builder.getPtrTy(), Ptr: retPtrPtr);
195 builder.CreateStore(Val: result, Ptr: retPtr);
196 }
197
198 // The interface function returns void.
199 builder.CreateRetVoid();
200 }
201}
202
203ExecutionEngine::ExecutionEngine(bool enableObjectDump,
204 bool enableGDBNotificationListener,
205 bool enablePerfNotificationListener)
206 : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
207 functionNames(),
208 gdbListener(enableGDBNotificationListener
209 ? llvm::JITEventListener::createGDBRegistrationListener()
210 : nullptr),
211 perfListener(nullptr) {
212 if (enablePerfNotificationListener) {
213 if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
214 perfListener = listener;
215 else if (auto *listener =
216 llvm::JITEventListener::createIntelJITEventListener())
217 perfListener = listener;
218 }
219}
220
221ExecutionEngine::~ExecutionEngine() {
222 // Execute the global destructors from the module being processed.
223 // TODO: Allow JIT deinitialize for AArch64. Currently there's a bug causing a
224 // crash for AArch64 see related issue #71963.
225 if (jit && !jit->getTargetTriple().isAArch64())
226 llvm::consumeError(Err: jit->deinitialize(JD&: jit->getMainJITDylib()));
227 // Run all dynamic library destroy callbacks to prepare for the shutdown.
228 for (LibraryDestroyFn destroy : destroyFns)
229 destroy();
230}
231
232Expected<std::unique_ptr<ExecutionEngine>>
233ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
234 std::unique_ptr<llvm::TargetMachine> tm) {
235 auto engine = std::make_unique<ExecutionEngine>(
236 args: options.enableObjectDump, args: options.enableGDBNotificationListener,
237 args: options.enablePerfNotificationListener);
238
239 // Remember all entry-points if object dumping is enabled.
240 if (options.enableObjectDump) {
241 for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
242 StringRef funcName = funcOp.getSymName();
243 engine->functionNames.push_back(funcName.str());
244 }
245 }
246
247 std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
248 auto llvmModule = options.llvmModuleBuilder
249 ? options.llvmModuleBuilder(m, *ctx)
250 : translateModuleToLLVMIR(module: m, llvmContext&: *ctx);
251 if (!llvmModule)
252 return makeStringError(message: "could not convert to LLVM IR");
253
254 // If no valid TargetMachine was passed, create a default TM ignoring any
255 // input arguments from the user.
256 if (!tm) {
257 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
258 if (!tmBuilderOrError)
259 return tmBuilderOrError.takeError();
260
261 auto tmOrError = tmBuilderOrError->createTargetMachine();
262 if (!tmOrError)
263 return tmOrError.takeError();
264 tm = std::move(tmOrError.get());
265 }
266
267 // TODO: Currently, the LLVM module created above has no triple associated
268 // with it. Instead, the triple is extracted from the TargetMachine, which is
269 // either based on the host defaults or command line arguments when specified
270 // (set-up by callers of this method). It could also be passed to the
271 // translation or dialect conversion instead of this.
272 setupTargetTripleAndDataLayout(llvmModule: llvmModule.get(), tm: tm.get());
273 packFunctionArguments(module: llvmModule.get());
274
275 auto dataLayout = llvmModule->getDataLayout();
276
277 // Use absolute library path so that gdb can find the symbol table.
278 SmallVector<SmallString<256>, 4> sharedLibPaths;
279 transform(
280 Range: options.sharedLibPaths, d_first: std::back_inserter(x&: sharedLibPaths),
281 F: [](StringRef libPath) {
282 SmallString<256> absPath(libPath.begin(), libPath.end());
283 cantFail(Err: llvm::errorCodeToError(EC: llvm::sys::fs::make_absolute(path&: absPath)));
284 return absPath;
285 });
286
287 // If shared library implements custom execution layer library init and
288 // destroy functions, we'll use them to register the library. Otherwise, load
289 // the library as JITDyLib below.
290 llvm::StringMap<void *> exportSymbols;
291 SmallVector<LibraryDestroyFn> destroyFns;
292 SmallVector<StringRef> jitDyLibPaths;
293
294 for (auto &libPath : sharedLibPaths) {
295 auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
296 filename: libPath.str().str().c_str());
297 void *initSym = lib.getAddressOfSymbol(symbolName: kLibraryInitFnName);
298 void *destroySim = lib.getAddressOfSymbol(symbolName: kLibraryDestroyFnName);
299
300 // Library does not provide call backs, rely on symbol visiblity.
301 if (!initSym || !destroySim) {
302 jitDyLibPaths.push_back(Elt: libPath);
303 continue;
304 }
305
306 auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
307 initFn(exportSymbols);
308
309 auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySim);
310 destroyFns.push_back(Elt: destroyFn);
311 }
312 engine->destroyFns = std::move(destroyFns);
313
314 // Callback to create the object layer with symbol resolution to current
315 // process and dynamically linked libraries.
316 auto objectLinkingLayerCreator = [&](ExecutionSession &session,
317 const Triple &tt) {
318 auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
319 args&: session, args: [sectionMemoryMapper = options.sectionMemoryMapper]() {
320 return std::make_unique<SectionMemoryManager>(args: sectionMemoryMapper);
321 });
322
323 // Register JIT event listeners if they are enabled.
324 if (engine->gdbListener)
325 objectLayer->registerJITEventListener(L&: *engine->gdbListener);
326 if (engine->perfListener)
327 objectLayer->registerJITEventListener(L&: *engine->perfListener);
328
329 // COFF format binaries (Windows) need special handling to deal with
330 // exported symbol visibility.
331 // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
332 llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
333 if (targetTriple.isOSBinFormatCOFF()) {
334 objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
335 objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
336 }
337
338 // Resolve symbols from shared libraries.
339 for (auto &libPath : jitDyLibPaths) {
340 auto mb = llvm::MemoryBuffer::getFile(Filename: libPath);
341 if (!mb) {
342 errs() << "Failed to create MemoryBuffer for: " << libPath
343 << "\nError: " << mb.getError().message() << "\n";
344 continue;
345 }
346 auto &jd = session.createBareJITDylib(Name: std::string(libPath));
347 auto loaded = DynamicLibrarySearchGenerator::Load(
348 FileName: libPath.str().c_str(), GlobalPrefix: dataLayout.getGlobalPrefix());
349 if (!loaded) {
350 errs() << "Could not load " << libPath << ":\n " << loaded.takeError()
351 << "\n";
352 continue;
353 }
354 jd.addGenerator(DefGenerator: std::move(*loaded));
355 cantFail(Err: objectLayer->add(JD&: jd, O: std::move(mb.get())));
356 }
357
358 return objectLayer;
359 };
360
361 // Callback to inspect the cache and recompile on demand. This follows Lang's
362 // LLJITWithObjectCache example.
363 auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
364 -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
365 if (options.jitCodeGenOptLevel)
366 jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
367 return std::make_unique<TMOwningSimpleCompiler>(args: std::move(tm),
368 args: engine->cache.get());
369 };
370
371 // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
372 auto jit =
373 cantFail(ValOrErr: llvm::orc::LLJITBuilder()
374 .setCompileFunctionCreator(compileFunctionCreator)
375 .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
376 .setDataLayout(dataLayout)
377 .create());
378
379 // Add a ThreadSafemodule to the engine and return.
380 ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
381 if (options.transformer)
382 cantFail(Err: tsm.withModuleDo(
383 F: [&](llvm::Module &module) { return options.transformer(&module); }));
384 cantFail(Err: jit->addIRModule(TSM: std::move(tsm)));
385 engine->jit = std::move(jit);
386
387 // Resolve symbols that are statically linked in the current process.
388 llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
389 mainJD.addGenerator(
390 DefGenerator: cantFail(ValOrErr: DynamicLibrarySearchGenerator::GetForCurrentProcess(
391 GlobalPrefix: dataLayout.getGlobalPrefix())));
392
393 // Build a runtime symbol map from the exported symbols and register them.
394 auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
395 auto symbolMap = llvm::orc::SymbolMap();
396 for (auto &exportSymbol : exportSymbols)
397 symbolMap[interner(exportSymbol.getKey())] = {
398 llvm::orc::ExecutorAddr::fromPtr(Ptr: exportSymbol.getValue()),
399 llvm::JITSymbolFlags::Exported};
400 return symbolMap;
401 };
402 engine->registerSymbols(symbolMap: runtimeSymbolMap);
403
404 // Execute the global constructors from the module being processed.
405 // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
406 // crash for AArch64 see related issue #71963.
407 if (!engine->jit->getTargetTriple().isAArch64())
408 cantFail(Err: engine->jit->initialize(JD&: engine->jit->getMainJITDylib()));
409
410 return std::move(engine);
411}
412
413Expected<void (*)(void **)>
414ExecutionEngine::lookupPacked(StringRef name) const {
415 auto result = lookup(name: makePackedFunctionName(name));
416 if (!result)
417 return result.takeError();
418 return reinterpret_cast<void (*)(void **)>(result.get());
419}
420
421Expected<void *> ExecutionEngine::lookup(StringRef name) const {
422 auto expectedSymbol = jit->lookup(UnmangledName: name);
423
424 // JIT lookup may return an Error referring to strings stored internally by
425 // the JIT. If the Error outlives the ExecutionEngine, it would want have a
426 // dangling reference, which is currently caught by an assertion inside JIT
427 // thanks to hand-rolled reference counting. Rewrap the error message into a
428 // string before returning. Alternatively, ORC JIT should consider copying
429 // the string into the error message.
430 if (!expectedSymbol) {
431 std::string errorMessage;
432 llvm::raw_string_ostream os(errorMessage);
433 llvm::handleAllErrors(E: expectedSymbol.takeError(),
434 Handlers: [&os](llvm::ErrorInfoBase &ei) { ei.log(OS&: os); });
435 return makeStringError(message: os.str());
436 }
437
438 if (void *fptr = expectedSymbol->toPtr<void *>())
439 return fptr;
440 return makeStringError(message: "looked up function is null");
441}
442
443Error ExecutionEngine::invokePacked(StringRef name,
444 MutableArrayRef<void *> args) {
445 auto expectedFPtr = lookupPacked(name);
446 if (!expectedFPtr)
447 return expectedFPtr.takeError();
448 auto fptr = *expectedFPtr;
449
450 (*fptr)(args.data());
451
452 return Error::success();
453}
454

source code of mlir/lib/ExecutionEngine/ExecutionEngine.cpp