1 | //===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a translation between the MLIR CUF dialect and LLVM IR. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h" |
14 | #include "flang/Optimizer/Dialect/CUF/CUFOps.h" |
15 | #include "flang/Runtime/entry-names.h" |
16 | #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" |
17 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
18 | #include "llvm/ADT/TypeSwitch.h" |
19 | #include "llvm/IR/IRBuilder.h" |
20 | #include "llvm/IR/Module.h" |
21 | #include "llvm/Support/FormatVariadic.h" |
22 | |
23 | using namespace mlir; |
24 | |
25 | namespace { |
26 | |
27 | LogicalResult registerModule(cuf::RegisterModuleOp op, |
28 | llvm::IRBuilderBase &builder, |
29 | LLVM::ModuleTranslation &moduleTranslation) { |
30 | std::string binaryIdentifier = |
31 | op.getName().getLeafReference().str() + "_binary" ; |
32 | llvm::Module *module = moduleTranslation.getLLVMModule(); |
33 | llvm::Value *binary = module->getGlobalVariable(Name: binaryIdentifier, AllowInternal: true); |
34 | if (!binary) |
35 | return op.emitError() << "Couldn't find the binary: " << binaryIdentifier; |
36 | |
37 | llvm::Type *ptrTy = builder.getPtrTy(AddrSpace: 0); |
38 | llvm::FunctionCallee fct = module->getOrInsertFunction( |
39 | RTNAME_STRING(CUFRegisterModule), |
40 | llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false)); |
41 | auto *handle = builder.CreateCall(Callee: fct, Args: {binary}); |
42 | moduleTranslation.mapValue(op->getResults().front()) = handle; |
43 | return mlir::success(); |
44 | } |
45 | |
46 | llvm::Value *getOrCreateFunctionName(llvm::Module *module, |
47 | llvm::IRBuilderBase &builder, |
48 | llvm::StringRef moduleName, |
49 | llvm::StringRef kernelName) { |
50 | std::string globalName = |
51 | std::string(llvm::formatv(Fmt: "{0}_{1}_kernel_name" , Vals&: moduleName, Vals&: kernelName)); |
52 | |
53 | if (llvm::GlobalVariable *gv = module->getGlobalVariable(Name: globalName)) |
54 | return gv; |
55 | |
56 | return builder.CreateGlobalString(Str: kernelName, Name: globalName); |
57 | } |
58 | |
59 | LogicalResult registerKernel(cuf::RegisterKernelOp op, |
60 | llvm::IRBuilderBase &builder, |
61 | LLVM::ModuleTranslation &moduleTranslation) { |
62 | llvm::Module *module = moduleTranslation.getLLVMModule(); |
63 | llvm::Type *ptrTy = builder.getPtrTy(AddrSpace: 0); |
64 | llvm::FunctionCallee fct = module->getOrInsertFunction( |
65 | RTNAME_STRING(CUFRegisterFunction), |
66 | llvm::FunctionType::get( |
67 | ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false)); |
68 | llvm::Value *modulePtr = moduleTranslation.lookupValue(value: op.getModulePtr()); |
69 | if (!modulePtr) |
70 | return op.emitError() << "Couldn't find the module ptr" ; |
71 | llvm::Function *fctSym = |
72 | moduleTranslation.lookupFunction(name: op.getKernelName().str()); |
73 | if (!fctSym) |
74 | return op.emitError() << "Couldn't find kernel name symbol: " |
75 | << op.getKernelName().str(); |
76 | builder.CreateCall(fct, {modulePtr, fctSym, |
77 | getOrCreateFunctionName( |
78 | module, builder, op.getKernelModuleName().str(), |
79 | op.getKernelName().str())}); |
80 | return mlir::success(); |
81 | } |
82 | |
83 | class CUFDialectLLVMIRTranslationInterface |
84 | : public LLVMTranslationDialectInterface { |
85 | public: |
86 | using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
87 | |
88 | LogicalResult |
89 | convertOperation(Operation *operation, llvm::IRBuilderBase &builder, |
90 | LLVM::ModuleTranslation &moduleTranslation) const override { |
91 | return llvm::TypeSwitch<Operation *, LogicalResult>(operation) |
92 | .Case(caseFn: [&](cuf::RegisterModuleOp op) { |
93 | return registerModule(op, builder, moduleTranslation); |
94 | }) |
95 | .Case(caseFn: [&](cuf::RegisterKernelOp op) { |
96 | return registerKernel(op, builder, moduleTranslation); |
97 | }) |
98 | .Default(defaultFn: [&](Operation *op) { |
99 | return op->emitError(message: "unsupported GPU operation: " ) << op->getName(); |
100 | }); |
101 | } |
102 | }; |
103 | |
104 | } // namespace |
105 | |
106 | void cuf::registerCUFDialectTranslation(DialectRegistry ®istry) { |
107 | registry.insert<cuf::CUFDialect>(); |
108 | registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) { |
109 | dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>(); |
110 | }); |
111 | } |
112 | |