//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the execution engine for MLIR modules based on LLVM Orc
// JIT engine.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Export.h"

#include "llvm/ExecutionEngine/JITEventListener.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/SubtargetFeature.h"

#define DEBUG_TYPE "execution-engine"

using namespace mlir;
using llvm::dbgs;
using llvm::Error;
using llvm::errs;
using llvm::Expected;
using llvm::LLVMContext;
using llvm::MemoryBuffer;
using llvm::MemoryBufferRef;
using llvm::Module;
using llvm::SectionMemoryManager;
using llvm::StringError;
using llvm::Triple;
using llvm::orc::DynamicLibrarySearchGenerator;
using llvm::orc::ExecutionSession;
using llvm::orc::IRCompileLayer;
using llvm::orc::JITTargetMachineBuilder;
using llvm::orc::MangleAndInterner;
using llvm::orc::RTDyldObjectLinkingLayer;
using llvm::orc::SymbolMap;
using llvm::orc::ThreadSafeModule;
using llvm::orc::TMOwningSimpleCompiler;
/// Wrap a string into an llvm::StringError.
static Error makeStringError(const Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

void SimpleObjectCache::notifyObjectCompiled(const Module *m,
                                             MemoryBufferRef objBuffer) {
  cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
      objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
}

std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
  auto i = cachedObjects.find(m->getModuleIdentifier());
  if (i == cachedObjects.end()) {
    LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
                      << " in cache. Compiling.\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
                    << " loaded from cache.\n");
  return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
}

void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
  // Set up the output file.
  std::string errorMessage;
  auto file = openOutputFile(outputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return;
  }

  // Dump the object generated for a single module to the output file.
  assert(cachedObjects.size() == 1 && "Expected only one object entry.");
  auto &cachedObject = cachedObjects.begin()->second;
  file->os() << cachedObject->getBuffer();
  file->keep();
}

bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }

void ExecutionEngine::dumpToObjectFile(StringRef filename) {
  if (cache == nullptr) {
    llvm::errs() << "cannot dump ExecutionEngine object code to file: "
                    "object cache is disabled\n";
    return;
  }
  // Compilation is lazy and does not populate the object cache unless
  // requested. In case an object dump is requested before the cache is
  // populated, force compilation manually.
  if (cache->isEmpty()) {
    for (std::string &functionName : functionNames) {
      auto result = lookupPacked(functionName);
      if (!result) {
        llvm::errs() << "Could not compile " << functionName << ":\n "
                     << result.takeError() << "\n";
        return;
      }
    }
  }
  cache->dumpToObjectFile(filename);
}

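// Example (hypothetical host callback): expose a process-side function to
// JIT'ed code under a chosen name, mirroring the runtime symbol map built in
// ExecutionEngine::create below.
//   engine->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
//     llvm::orc::SymbolMap symbolMap;
//     symbolMap[interner("host_print")] = {
//         llvm::orc::ExecutorAddr::fromPtr(&hostPrint),
//         llvm::JITSymbolFlags::Exported};
//     return symbolMap;
//   });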
void ExecutionEngine::registerSymbols(
    llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
  auto &mainJitDylib = jit->getMainJITDylib();
  cantFail(mainJitDylib.define(
      absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
          mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
}

void ExecutionEngine::setupTargetTripleAndDataLayout(Module *llvmModule,
                                                     llvm::TargetMachine *tm) {
  llvmModule->setDataLayout(tm->createDataLayout());
  llvmModule->setTargetTriple(tm->getTargetTriple().getTriple());
}

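/// Form the name of the packed-interface wrapper for `name` by prefixing it
/// with "_mlir_"; lookupPacked below relies on the same convention.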
static std::string makePackedFunctionName(StringRef name) {
  return "_mlir_" + name.str();
}

// For each function in the LLVM module, define an interface function that
// wraps all the arguments of the original function and all its results into an
// i8** pointer to provide a unified invocation interface.
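// For example, for a function `i32 @foo(i32, f32)` this creates
// `_mlir_foo(ptr)`, whose single operand points to an array of pointers: the
// first two entries point to the actual arguments and the last entry points to
// the storage that receives the result.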
static void packFunctionArguments(Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    if (func.isDeclaration()) {
      continue;
    }
    if (interfaceFunctions.count(&func)) {
      continue;
    }

    // Given a function `foo(<...>)`, define the interface function
    // `_mlir_foo(i8**)`.
    auto *newType =
        llvm::FunctionType::get(builder.getVoidTy(), builder.getPtrTy(),
                                /*isVarArg=*/false);
    auto newName = makePackedFunctionName(func.getName());
    auto funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them
    // to the proper types.
    auto *bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto [index, arg] : llvm::enumerate(func.args())) {
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, index));
      llvm::Value *argPtrPtr =
          builder.CreateGEP(builder.getPtrTy(), argList, argIndex);
      llvm::Value *argPtr = builder.CreateLoad(builder.getPtrTy(), argPtrPtr);
      llvm::Type *argTy = arg.getType();
      llvm::Value *load = builder.CreateLoad(argTy, argPtr);
      args.push_back(load);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`.
    if (!result->getType()->isVoidTy()) {
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr =
          builder.CreateGEP(builder.getPtrTy(), argList, retIndex);
      llvm::Value *retPtr = builder.CreateLoad(builder.getPtrTy(), retPtrPtr);
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}

ExecutionEngine::ExecutionEngine(bool enableObjectDump,
                                 bool enableGDBNotificationListener,
                                 bool enablePerfNotificationListener)
    : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
      functionNames(),
      gdbListener(enableGDBNotificationListener
                      ? llvm::JITEventListener::createGDBRegistrationListener()
                      : nullptr),
      perfListener(nullptr) {
  if (enablePerfNotificationListener) {
    if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
      perfListener = listener;
    else if (auto *listener =
                 llvm::JITEventListener::createIntelJITEventListener())
      perfListener = listener;
  }
}

ExecutionEngine::~ExecutionEngine() {
  // Execute the global destructors from the module being processed.
  // TODO: Allow JIT deinitialization for AArch64. Currently there is a bug
  // that causes a crash on AArch64; see related issue #71963.
  if (jit && !jit->getTargetTriple().isAArch64())
    llvm::consumeError(jit->deinitialize(jit->getMainJITDylib()));
  // Run all dynamic library destroy callbacks to prepare for the shutdown.
  for (LibraryDestroyFn destroy : destroyFns)
    destroy();
}

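// Typical usage (hypothetical caller; `module` is an MLIR module already
// lowered to the LLVM dialect):
//   mlir::ExecutionEngineOptions options;
//   auto expectedEngine = mlir::ExecutionEngine::create(module, options);
//   if (!expectedEngine)
//     /* handle llvm::Error from expectedEngine.takeError() */;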
Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
                        std::unique_ptr<llvm::TargetMachine> tm) {
  auto engine = std::make_unique<ExecutionEngine>(
      options.enableObjectDump, options.enableGDBNotificationListener,
      options.enablePerfNotificationListener);

  // Remember all entry-points if object dumping is enabled.
  if (options.enableObjectDump) {
    for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
      StringRef funcName = funcOp.getSymName();
      engine->functionNames.push_back(funcName.str());
    }
  }

  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
  auto llvmModule = options.llvmModuleBuilder
                        ? options.llvmModuleBuilder(m, *ctx)
                        : translateModuleToLLVMIR(m, *ctx);
  if (!llvmModule)
    return makeStringError("could not convert to LLVM IR");

  // If no valid TargetMachine was passed, create a default TM ignoring any
  // input arguments from the user.
  if (!tm) {
    auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
    if (!tmBuilderOrError)
      return tmBuilderOrError.takeError();

    auto tmOrError = tmBuilderOrError->createTargetMachine();
    if (!tmOrError)
      return tmOrError.takeError();
    tm = std::move(tmOrError.get());
  }

  // TODO: Currently, the LLVM module created above has no triple associated
  // with it. Instead, the triple is extracted from the TargetMachine, which is
  // either based on the host defaults or command line arguments when specified
  // (set-up by callers of this method). It could also be passed to the
  // translation or dialect conversion instead of this.
  setupTargetTripleAndDataLayout(llvmModule.get(), tm.get());
  packFunctionArguments(llvmModule.get());

  auto dataLayout = llvmModule->getDataLayout();

  // Use absolute library path so that gdb can find the symbol table.
  SmallVector<SmallString<256>, 4> sharedLibPaths;
  transform(
      options.sharedLibPaths, std::back_inserter(sharedLibPaths),
      [](StringRef libPath) {
        SmallString<256> absPath(libPath.begin(), libPath.end());
        cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
        return absPath;
      });

  // If the shared library implements custom execution-layer library init and
  // destroy functions, use them to register the library. Otherwise, load the
  // library as a JITDylib below.
  llvm::StringMap<void *> exportSymbols;
  SmallVector<LibraryDestroyFn> destroyFns;
  SmallVector<StringRef> jitDyLibPaths;

  for (auto &libPath : sharedLibPaths) {
    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
        libPath.str().str().c_str());
    void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName);
    void *destroySym = lib.getAddressOfSymbol(kLibraryDestroyFnName);

    // The library does not provide callbacks; rely on symbol visibility.
    if (!initSym || !destroySym) {
      jitDyLibPaths.push_back(libPath);
      continue;
    }

    auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
    initFn(exportSymbols);

    auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySym);
    destroyFns.push_back(destroyFn);
  }
  engine->destroyFns = std::move(destroyFns);

  // Callback to create the object layer with symbol resolution to current
  // process and dynamically linked libraries.
  auto objectLinkingLayerCreator = [&](ExecutionSession &session,
                                       const Triple &tt) {
    auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
        session, [sectionMemoryMapper = options.sectionMemoryMapper]() {
          return std::make_unique<SectionMemoryManager>(sectionMemoryMapper);
        });

    // Register JIT event listeners if they are enabled.
    if (engine->gdbListener)
      objectLayer->registerJITEventListener(*engine->gdbListener);
    if (engine->perfListener)
      objectLayer->registerJITEventListener(*engine->perfListener);

    // COFF format binaries (Windows) need special handling to deal with
    // exported symbol visibility.
    // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
    llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
    if (targetTriple.isOSBinFormatCOFF()) {
      objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
      objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
    }

    // Resolve symbols from shared libraries.
    for (auto &libPath : jitDyLibPaths) {
      auto mb = llvm::MemoryBuffer::getFile(libPath);
      if (!mb) {
        errs() << "Failed to create MemoryBuffer for: " << libPath
               << "\nError: " << mb.getError().message() << "\n";
        continue;
      }
      auto &jd = session.createBareJITDylib(std::string(libPath));
      auto loaded = DynamicLibrarySearchGenerator::Load(
          libPath.str().c_str(), dataLayout.getGlobalPrefix());
      if (!loaded) {
        errs() << "Could not load " << libPath << ":\n " << loaded.takeError()
               << "\n";
        continue;
      }
      jd.addGenerator(std::move(*loaded));
      cantFail(objectLayer->add(jd, std::move(mb.get())));
    }

    return objectLayer;
  };

  // Callback to inspect the cache and recompile on demand. This follows Lang's
  // LLJITWithObjectCache example.
  auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
      -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
    if (options.jitCodeGenOptLevel)
      jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
    return std::make_unique<TMOwningSimpleCompiler>(std::move(tm),
                                                    engine->cache.get());
  };

  // Create the LLJIT by calling the LLJITBuilder with the two callbacks.
  auto jit =
      cantFail(llvm::orc::LLJITBuilder()
                   .setCompileFunctionCreator(compileFunctionCreator)
                   .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
                   .setDataLayout(dataLayout)
                   .create());

  // Add a ThreadSafeModule to the engine and return.
  ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
  if (options.transformer)
    cantFail(tsm.withModuleDo(
        [&](llvm::Module &module) { return options.transformer(&module); }));
  cantFail(jit->addIRModule(std::move(tsm)));
  engine->jit = std::move(jit);

  // Resolve symbols that are statically linked in the current process.
  llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
  mainJD.addGenerator(
      cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix())));

  // Build a runtime symbol map from the exported symbols and register them.
  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
    auto symbolMap = llvm::orc::SymbolMap();
    for (auto &exportSymbol : exportSymbols)
      symbolMap[interner(exportSymbol.getKey())] = {
          llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()),
          llvm::JITSymbolFlags::Exported};
    return symbolMap;
  };
  engine->registerSymbols(runtimeSymbolMap);

  // Execute the global constructors from the module being processed.
  // TODO: Allow JIT initialization for AArch64. Currently there is a bug that
  // causes a crash on AArch64; see related issue #71963.
  if (!engine->jit->getTargetTriple().isAArch64())
    cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));

  return std::move(engine);
}

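// Look up the packed-interface wrapper ("_mlir_" + name) produced by
// packFunctionArguments and return it as a `void (*)(void **)`.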
Expected<void (*)(void **)>
ExecutionEngine::lookupPacked(StringRef name) const {
  auto result = lookup(makePackedFunctionName(name));
  if (!result)
    return result.takeError();
  return reinterpret_cast<void (*)(void **)>(result.get());
}

Expected<void *> ExecutionEngine::lookup(StringRef name) const {
  auto expectedSymbol = jit->lookup(name);

  // JIT lookup may return an Error referring to strings stored internally by
  // the JIT. If the Error outlives the ExecutionEngine, it would contain a
  // dangling reference, which is currently caught by an assertion inside JIT
  // thanks to hand-rolled reference counting. Rewrap the error message into a
  // string before returning. Alternatively, ORC JIT should consider copying
  // the string into the error message.
  if (!expectedSymbol) {
    std::string errorMessage;
    llvm::raw_string_ostream os(errorMessage);
    llvm::handleAllErrors(expectedSymbol.takeError(),
                          [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
    return makeStringError(os.str());
  }

  if (void *fptr = expectedSymbol->toPtr<void *>())
    return fptr;
  return makeStringError("looked up function is null");
}

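// `args` must hold one pointer per argument of the original function, followed
// by one pointer to the storage for the result (if any), matching the layout
// produced by packFunctionArguments. For example (hypothetical function `add`
// taking two i32 values and returning an i32):
//   int32_t a = 1, b = 2, res = 0;
//   llvm::SmallVector<void *> args = {&a, &b, &res};
//   llvm::Error error = engine->invokePacked("add", args);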
Error ExecutionEngine::invokePacked(StringRef name,
                                    MutableArrayRef<void *> args) {
  auto expectedFPtr = lookupPacked(name);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  auto fptr = *expectedFPtr;

  (*fptr)(args.data());

  return Error::success();
}