1//===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR ROCDL dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
15#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Target/LLVMIR/ModuleTranslation.h"
19
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/IntrinsicsAMDGPU.h"
22#include "llvm/IR/MDBuilder.h"
23#include "llvm/Support/raw_ostream.h"
24
25using namespace mlir;
26using namespace mlir::LLVM;
27using mlir::LLVM::detail::createIntrinsicCall;
28
29static llvm::Value *createIntrinsicCallWithRange(llvm::IRBuilderBase &builder,
30 llvm::Intrinsic::ID intrinsic,
31 DenseI32ArrayAttr maybeRange) {
32 auto *inst = llvm::cast<llvm::CallInst>(
33 Val: createIntrinsicCall(builder, intrinsic, args: {}, tys: {}));
34 if (maybeRange) {
35 SmallVector<llvm::APInt, 2> apInts;
36 for (int32_t i : maybeRange.asArrayRef())
37 apInts.push_back(llvm::APInt(32, i));
38 llvm::MDBuilder mdBuilder(builder.getContext());
39 llvm::MDNode *range = mdBuilder.createRange(Lo: apInts[0], Hi: apInts[1]);
40 inst->setMetadata(KindID: llvm::LLVMContext::MD_range, Node: range);
41 }
42 return inst;
43}
44
45// Create a call to ROCm-Device-Library function
46// Currently this routine will work only for calling ROCDL functions that
47// take a single int32 argument. It is likely that the interface of this
48// function will change to make it more generic.
49static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
50 StringRef fnName, int parameter) {
51 llvm::Module *module = builder.GetInsertBlock()->getModule();
52 llvm::FunctionType *functionType = llvm::FunctionType::get(
53 Result: llvm::Type::getInt64Ty(C&: module->getContext()), // return type.
54 Params: llvm::Type::getInt32Ty(C&: module->getContext()), // parameter type.
55 isVarArg: false); // no variadic arguments.
56 llvm::Function *fn = dyn_cast<llvm::Function>(
57 Val: module->getOrInsertFunction(Name: fnName, T: functionType).getCallee());
58 llvm::Value *fnOp0 = llvm::ConstantInt::get(
59 Ty: llvm::Type::getInt32Ty(C&: module->getContext()), V: parameter);
60 return builder.CreateCall(Callee: fn, Args: ArrayRef<llvm::Value *>(fnOp0));
61}
62
63namespace {
64/// Implementation of the dialect interface that converts operations belonging
65/// to the ROCDL dialect to LLVM IR.
66class ROCDLDialectLLVMIRTranslationInterface
67 : public LLVMTranslationDialectInterface {
68public:
69 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
70
71 /// Translates the given operation to LLVM IR using the provided IR builder
72 /// and saving the state in `moduleTranslation`.
73 LogicalResult
74 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
75 LLVM::ModuleTranslation &moduleTranslation) const final {
76 Operation &opInst = *op;
77#include "mlir/Dialect/LLVMIR/ROCDLConversions.inc"
78
79 return failure();
80 }
81
82 /// Attaches module-level metadata for functions marked as kernels.
83 LogicalResult
84 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
85 NamedAttribute attribute,
86 LLVM::ModuleTranslation &moduleTranslation) const final {
87 auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
88 if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
89 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
90 if (!func)
91 return op->emitOpError(message: Twine(attribute.getName()) +
92 " is only supported on `llvm.func` operations");
93 ;
94
95 // For GPU kernels,
96 // 1. Insert AMDGPU_KERNEL calling convention.
97 // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user
98 // has overriden this value - 256 is the default in clang
99 llvm::Function *llvmFunc =
100 moduleTranslation.lookupFunction(name: func.getName());
101 llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
102 if (!llvmFunc->hasFnAttribute(Kind: "amdgpu-flat-work-group-size")) {
103 llvmFunc->addFnAttr(Kind: "amdgpu-flat-work-group-size", Val: "1,256");
104 }
105
106 // MLIR's GPU kernel APIs all assume and produce uniformly-sized
107 // workgroups, so the lowering of the `rocdl.kernel` marker encodes this
108 // assumption. This assumption may be overridden by setting
109 // `rocdl.uniform_work_group_size` on a given function.
110 if (!llvmFunc->hasFnAttribute(Kind: "uniform-work-group-size"))
111 llvmFunc->addFnAttr(Kind: "uniform-work-group-size", Val: "true");
112 }
113 // Override flat-work-group-size
114 // TODO: update clients to rocdl.flat_work_group_size instead,
115 // then remove this half of the branch
116 if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
117 attribute.getName()) {
118 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
119 if (!func)
120 return op->emitOpError(message: Twine(attribute.getName()) +
121 " is only supported on `llvm.func` operations");
122 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
123 if (!value)
124 return op->emitOpError(message: Twine(attribute.getName()) +
125 " must be an integer");
126
127 llvm::Function *llvmFunc =
128 moduleTranslation.lookupFunction(name: func.getName());
129 llvm::SmallString<8> llvmAttrValue;
130 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
131 attrValueStream << "1," << value.getInt();
132 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
133 }
134 if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
135 attribute.getName()) {
136 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
137 if (!func)
138 return op->emitOpError(message: Twine(attribute.getName()) +
139 " is only supported on `llvm.func` operations");
140 auto value = dyn_cast<StringAttr>(attribute.getValue());
141 if (!value)
142 return op->emitOpError(message: Twine(attribute.getName()) +
143 " must be a string");
144
145 llvm::Function *llvmFunc =
146 moduleTranslation.lookupFunction(name: func.getName());
147 llvm::SmallString<8> llvmAttrValue;
148 llvmAttrValue.append(value.getValue());
149 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
150 }
151 if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() ==
152 attribute.getName()) {
153 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
154 if (!func)
155 return op->emitOpError(message: Twine(attribute.getName()) +
156 " is only supported on `llvm.func` operations");
157 auto value = dyn_cast<BoolAttr>(attribute.getValue());
158 if (!value)
159 return op->emitOpError(message: Twine(attribute.getName()) +
160 " must be a boolean");
161 llvm::Function *llvmFunc =
162 moduleTranslation.lookupFunction(name: func.getName());
163 llvmFunc->addFnAttr("uniform-work-group-size",
164 value.getValue() ? "true" : "false");
165 }
166 // Set reqd_work_group_size metadata
167 if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
168 attribute.getName()) {
169 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
170 if (!func)
171 return op->emitOpError(message: Twine(attribute.getName()) +
172 " is only supported on `llvm.func` operations");
173 auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue());
174 if (!value)
175 return op->emitOpError(message: Twine(attribute.getName()) +
176 " must be a dense i32 array attribute");
177 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
178 SmallVector<llvm::Metadata *, 3> metadata;
179 llvm::Type *i32 = llvm::IntegerType::get(C&: llvmContext, NumBits: 32);
180 for (int32_t i : value.asArrayRef()) {
181 llvm::Constant *constant = llvm::ConstantInt::get(i32, i);
182 metadata.push_back(llvm::ConstantAsMetadata::get(constant));
183 }
184 llvm::Function *llvmFunc =
185 moduleTranslation.lookupFunction(name: func.getName());
186 llvm::MDNode *node = llvm::MDNode::get(Context&: llvmContext, MDs: metadata);
187 llvmFunc->setMetadata(Kind: "reqd_work_group_size", Node: node);
188 }
189 return success();
190 }
191};
192} // namespace
193
194void mlir::registerROCDLDialectTranslation(DialectRegistry &registry) {
195 registry.insert<ROCDL::ROCDLDialect>();
196 registry.addExtension(extensionFn: +[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) {
197 dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
198 });
199}
200
201void mlir::registerROCDLDialectTranslation(MLIRContext &context) {
202 DialectRegistry registry;
203 registerROCDLDialectTranslation(registry);
204 context.appendDialectRegistry(registry);
205}
206

source code of mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp