//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR CUF dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
14#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15#include "flang/Runtime/entry-names.h"
16#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
17#include "mlir/Target/LLVMIR/ModuleTranslation.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/Module.h"
21#include "llvm/Support/FormatVariadic.h"
22
23using namespace mlir;
24
25namespace {
26
27LogicalResult registerModule(cuf::RegisterModuleOp op,
28 llvm::IRBuilderBase &builder,
29 LLVM::ModuleTranslation &moduleTranslation) {
30 std::string binaryIdentifier =
31 op.getName().getLeafReference().str() + "_binary";
32 llvm::Module *module = moduleTranslation.getLLVMModule();
33 llvm::Value *binary = module->getGlobalVariable(Name: binaryIdentifier, AllowInternal: true);
34 if (!binary)
35 return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
36
37 llvm::Type *ptrTy = builder.getPtrTy(AddrSpace: 0);
38 llvm::FunctionCallee fct = module->getOrInsertFunction(
39 RTNAME_STRING(CUFRegisterModule),
40 llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false));
41 auto *handle = builder.CreateCall(Callee: fct, Args: {binary});
42 moduleTranslation.mapValue(op->getResults().front()) = handle;
43 return mlir::success();
44}
45
46llvm::Value *getOrCreateFunctionName(llvm::Module *module,
47 llvm::IRBuilderBase &builder,
48 llvm::StringRef moduleName,
49 llvm::StringRef kernelName) {
50 std::string globalName =
51 std::string(llvm::formatv(Fmt: "{0}_{1}_kernel_name", Vals&: moduleName, Vals&: kernelName));
52
53 if (llvm::GlobalVariable *gv = module->getGlobalVariable(Name: globalName))
54 return gv;
55
56 return builder.CreateGlobalString(Str: kernelName, Name: globalName);
57}
58
59LogicalResult registerKernel(cuf::RegisterKernelOp op,
60 llvm::IRBuilderBase &builder,
61 LLVM::ModuleTranslation &moduleTranslation) {
62 llvm::Module *module = moduleTranslation.getLLVMModule();
63 llvm::Type *ptrTy = builder.getPtrTy(AddrSpace: 0);
64 llvm::FunctionCallee fct = module->getOrInsertFunction(
65 RTNAME_STRING(CUFRegisterFunction),
66 llvm::FunctionType::get(
67 ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
68 llvm::Value *modulePtr = moduleTranslation.lookupValue(value: op.getModulePtr());
69 if (!modulePtr)
70 return op.emitError() << "Couldn't find the module ptr";
71 llvm::Function *fctSym =
72 moduleTranslation.lookupFunction(name: op.getKernelName().str());
73 if (!fctSym)
74 return op.emitError() << "Couldn't find kernel name symbol: "
75 << op.getKernelName().str();
76 builder.CreateCall(fct, {modulePtr, fctSym,
77 getOrCreateFunctionName(
78 module, builder, op.getKernelModuleName().str(),
79 op.getKernelName().str())});
80 return mlir::success();
81}
82
83class CUFDialectLLVMIRTranslationInterface
84 : public LLVMTranslationDialectInterface {
85public:
86 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
87
88 LogicalResult
89 convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
90 LLVM::ModuleTranslation &moduleTranslation) const override {
91 return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
92 .Case(caseFn: [&](cuf::RegisterModuleOp op) {
93 return registerModule(op, builder, moduleTranslation);
94 })
95 .Case(caseFn: [&](cuf::RegisterKernelOp op) {
96 return registerKernel(op, builder, moduleTranslation);
97 })
98 .Default(defaultFn: [&](Operation *op) {
99 return op->emitError(message: "unsupported GPU operation: ") << op->getName();
100 });
101 }
102};
103
104} // namespace
105
106void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
107 registry.insert<cuf::CUFDialect>();
108 registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
109 dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
110 });
111}
112

source code of flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp