1 | //===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a translation between the MLIR ROCDL dialect and |
10 | // LLVM IR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" |
15 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
19 | |
20 | #include "llvm/IR/IRBuilder.h" |
21 | #include "llvm/IR/IntrinsicsAMDGPU.h" |
22 | #include "llvm/IR/MDBuilder.h" |
23 | #include "llvm/Support/raw_ostream.h" |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::LLVM; |
27 | using mlir::LLVM::detail::createIntrinsicCall; |
28 | |
29 | static llvm::Value *createIntrinsicCallWithRange(llvm::IRBuilderBase &builder, |
30 | llvm::Intrinsic::ID intrinsic, |
31 | DenseI32ArrayAttr maybeRange) { |
32 | auto *inst = llvm::cast<llvm::CallInst>( |
33 | Val: createIntrinsicCall(builder, intrinsic, args: {}, tys: {})); |
34 | if (maybeRange) { |
35 | SmallVector<llvm::APInt, 2> apInts; |
36 | for (int32_t i : maybeRange.asArrayRef()) |
37 | apInts.push_back(llvm::APInt(32, i)); |
38 | llvm::MDBuilder mdBuilder(builder.getContext()); |
39 | llvm::MDNode *range = mdBuilder.createRange(Lo: apInts[0], Hi: apInts[1]); |
40 | inst->setMetadata(KindID: llvm::LLVMContext::MD_range, Node: range); |
41 | } |
42 | return inst; |
43 | } |
44 | |
45 | // Create a call to ROCm-Device-Library function |
46 | // Currently this routine will work only for calling ROCDL functions that |
47 | // take a single int32 argument. It is likely that the interface of this |
48 | // function will change to make it more generic. |
49 | static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, |
50 | StringRef fnName, int parameter) { |
51 | llvm::Module *module = builder.GetInsertBlock()->getModule(); |
52 | llvm::FunctionType *functionType = llvm::FunctionType::get( |
53 | Result: llvm::Type::getInt64Ty(C&: module->getContext()), // return type. |
54 | Params: llvm::Type::getInt32Ty(C&: module->getContext()), // parameter type. |
55 | isVarArg: false); // no variadic arguments. |
56 | llvm::Function *fn = dyn_cast<llvm::Function>( |
57 | Val: module->getOrInsertFunction(Name: fnName, T: functionType).getCallee()); |
58 | llvm::Value *fnOp0 = llvm::ConstantInt::get( |
59 | Ty: llvm::Type::getInt32Ty(C&: module->getContext()), V: parameter); |
60 | return builder.CreateCall(Callee: fn, Args: ArrayRef<llvm::Value *>(fnOp0)); |
61 | } |
62 | |
63 | namespace { |
64 | /// Implementation of the dialect interface that converts operations belonging |
65 | /// to the ROCDL dialect to LLVM IR. |
66 | class ROCDLDialectLLVMIRTranslationInterface |
67 | : public LLVMTranslationDialectInterface { |
68 | public: |
69 | using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
70 | |
71 | /// Translates the given operation to LLVM IR using the provided IR builder |
72 | /// and saving the state in `moduleTranslation`. |
73 | LogicalResult |
74 | convertOperation(Operation *op, llvm::IRBuilderBase &builder, |
75 | LLVM::ModuleTranslation &moduleTranslation) const final { |
76 | Operation &opInst = *op; |
77 | #include "mlir/Dialect/LLVMIR/ROCDLConversions.inc" |
78 | |
79 | return failure(); |
80 | } |
81 | |
82 | /// Attaches module-level metadata for functions marked as kernels. |
83 | LogicalResult |
84 | amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, |
85 | NamedAttribute attribute, |
86 | LLVM::ModuleTranslation &moduleTranslation) const final { |
87 | auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect()); |
88 | if (dialect->getKernelAttrHelper().getName() == attribute.getName()) { |
89 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
90 | if (!func) |
91 | return op->emitOpError(message: Twine(attribute.getName()) + |
92 | " is only supported on `llvm.func` operations" ); |
93 | ; |
94 | |
95 | // For GPU kernels, |
96 | // 1. Insert AMDGPU_KERNEL calling convention. |
97 | // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user |
98 | // has overriden this value - 256 is the default in clang |
99 | llvm::Function *llvmFunc = |
100 | moduleTranslation.lookupFunction(name: func.getName()); |
101 | llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); |
102 | if (!llvmFunc->hasFnAttribute(Kind: "amdgpu-flat-work-group-size" )) { |
103 | llvmFunc->addFnAttr(Kind: "amdgpu-flat-work-group-size" , Val: "1,256" ); |
104 | } |
105 | |
106 | // MLIR's GPU kernel APIs all assume and produce uniformly-sized |
107 | // workgroups, so the lowering of the `rocdl.kernel` marker encodes this |
108 | // assumption. This assumption may be overridden by setting |
109 | // `rocdl.uniform_work_group_size` on a given function. |
110 | if (!llvmFunc->hasFnAttribute(Kind: "uniform-work-group-size" )) |
111 | llvmFunc->addFnAttr(Kind: "uniform-work-group-size" , Val: "true" ); |
112 | } |
113 | // Override flat-work-group-size |
114 | // TODO: update clients to rocdl.flat_work_group_size instead, |
115 | // then remove this half of the branch |
116 | if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() == |
117 | attribute.getName()) { |
118 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
119 | if (!func) |
120 | return op->emitOpError(message: Twine(attribute.getName()) + |
121 | " is only supported on `llvm.func` operations" ); |
122 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
123 | if (!value) |
124 | return op->emitOpError(message: Twine(attribute.getName()) + |
125 | " must be an integer" ); |
126 | |
127 | llvm::Function *llvmFunc = |
128 | moduleTranslation.lookupFunction(name: func.getName()); |
129 | llvm::SmallString<8> llvmAttrValue; |
130 | llvm::raw_svector_ostream attrValueStream(llvmAttrValue); |
131 | attrValueStream << "1," << value.getInt(); |
132 | llvmFunc->addFnAttr("amdgpu-flat-work-group-size" , llvmAttrValue); |
133 | } |
134 | if (dialect->getFlatWorkGroupSizeAttrHelper().getName() == |
135 | attribute.getName()) { |
136 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
137 | if (!func) |
138 | return op->emitOpError(message: Twine(attribute.getName()) + |
139 | " is only supported on `llvm.func` operations" ); |
140 | auto value = dyn_cast<StringAttr>(attribute.getValue()); |
141 | if (!value) |
142 | return op->emitOpError(message: Twine(attribute.getName()) + |
143 | " must be a string" ); |
144 | |
145 | llvm::Function *llvmFunc = |
146 | moduleTranslation.lookupFunction(name: func.getName()); |
147 | llvm::SmallString<8> llvmAttrValue; |
148 | llvmAttrValue.append(value.getValue()); |
149 | llvmFunc->addFnAttr("amdgpu-flat-work-group-size" , llvmAttrValue); |
150 | } |
151 | if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() == |
152 | attribute.getName()) { |
153 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
154 | if (!func) |
155 | return op->emitOpError(message: Twine(attribute.getName()) + |
156 | " is only supported on `llvm.func` operations" ); |
157 | auto value = dyn_cast<BoolAttr>(attribute.getValue()); |
158 | if (!value) |
159 | return op->emitOpError(message: Twine(attribute.getName()) + |
160 | " must be a boolean" ); |
161 | llvm::Function *llvmFunc = |
162 | moduleTranslation.lookupFunction(name: func.getName()); |
163 | llvmFunc->addFnAttr("uniform-work-group-size" , |
164 | value.getValue() ? "true" : "false" ); |
165 | } |
166 | // Set reqd_work_group_size metadata |
167 | if (dialect->getReqdWorkGroupSizeAttrHelper().getName() == |
168 | attribute.getName()) { |
169 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
170 | if (!func) |
171 | return op->emitOpError(message: Twine(attribute.getName()) + |
172 | " is only supported on `llvm.func` operations" ); |
173 | auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue()); |
174 | if (!value) |
175 | return op->emitOpError(message: Twine(attribute.getName()) + |
176 | " must be a dense i32 array attribute" ); |
177 | llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); |
178 | SmallVector<llvm::Metadata *, 3> metadata; |
179 | llvm::Type *i32 = llvm::IntegerType::get(C&: llvmContext, NumBits: 32); |
180 | for (int32_t i : value.asArrayRef()) { |
181 | llvm::Constant *constant = llvm::ConstantInt::get(i32, i); |
182 | metadata.push_back(llvm::ConstantAsMetadata::get(constant)); |
183 | } |
184 | llvm::Function *llvmFunc = |
185 | moduleTranslation.lookupFunction(name: func.getName()); |
186 | llvm::MDNode *node = llvm::MDNode::get(Context&: llvmContext, MDs: metadata); |
187 | llvmFunc->setMetadata(Kind: "reqd_work_group_size" , Node: node); |
188 | } |
189 | return success(); |
190 | } |
191 | }; |
192 | } // namespace |
193 | |
194 | void mlir::registerROCDLDialectTranslation(DialectRegistry ®istry) { |
195 | registry.insert<ROCDL::ROCDLDialect>(); |
196 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) { |
197 | dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>(); |
198 | }); |
199 | } |
200 | |
201 | void mlir::registerROCDLDialectTranslation(MLIRContext &context) { |
202 | DialectRegistry registry; |
203 | registerROCDLDialectTranslation(registry); |
204 | context.appendDialectRegistry(registry); |
205 | } |
206 | |