1 | //===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a translation between the MLIR ROCDL dialect and |
10 | // LLVM IR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" |
15 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
19 | |
20 | #include "llvm/IR/ConstantRange.h" |
21 | #include "llvm/IR/IRBuilder.h" |
22 | #include "llvm/IR/IntrinsicsAMDGPU.h" |
23 | #include "llvm/Support/raw_ostream.h" |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::LLVM; |
27 | using mlir::LLVM::detail::createIntrinsicCall; |
28 | |
29 | // Create a call to ROCm-Device-Library function that returns an ID. |
30 | // This is intended to specifically call device functions that fetch things like |
31 | // block or grid dimensions, and so is limited to functions that take one |
32 | // integer parameter. |
33 | static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder, |
34 | Operation *op, StringRef fnName, |
35 | int parameter) { |
36 | llvm::Module *module = builder.GetInsertBlock()->getModule(); |
37 | llvm::FunctionType *functionType = llvm::FunctionType::get( |
38 | Result: llvm::Type::getInt64Ty(C&: module->getContext()), // return type. |
39 | Params: llvm::Type::getInt32Ty(C&: module->getContext()), // parameter type. |
40 | isVarArg: false); // no variadic arguments. |
41 | llvm::Function *fn = dyn_cast<llvm::Function>( |
42 | Val: module->getOrInsertFunction(Name: fnName, T: functionType).getCallee()); |
43 | llvm::Value *fnOp0 = llvm::ConstantInt::get( |
44 | Ty: llvm::Type::getInt32Ty(C&: module->getContext()), V: parameter); |
45 | auto *call = builder.CreateCall(Callee: fn, Args: ArrayRef<llvm::Value *>(fnOp0)); |
46 | if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range" )) { |
47 | // Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but |
48 | // these ockl functions are defined to be 64-bits |
49 | call->addRangeRetAttr(CR: llvm::ConstantRange(rangeAttr.getLower().zext(64), |
50 | rangeAttr.getUpper().zext(64))); |
51 | } |
52 | return call; |
53 | } |
54 | |
55 | namespace { |
56 | /// Implementation of the dialect interface that converts operations belonging |
57 | /// to the ROCDL dialect to LLVM IR. |
58 | class ROCDLDialectLLVMIRTranslationInterface |
59 | : public LLVMTranslationDialectInterface { |
60 | public: |
61 | using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
62 | |
63 | /// Translates the given operation to LLVM IR using the provided IR builder |
64 | /// and saving the state in `moduleTranslation`. |
65 | LogicalResult |
66 | convertOperation(Operation *op, llvm::IRBuilderBase &builder, |
67 | LLVM::ModuleTranslation &moduleTranslation) const final { |
68 | Operation &opInst = *op; |
69 | #include "mlir/Dialect/LLVMIR/ROCDLConversions.inc" |
70 | |
71 | return failure(); |
72 | } |
73 | |
74 | /// Attaches module-level metadata for functions marked as kernels. |
75 | LogicalResult |
76 | amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, |
77 | NamedAttribute attribute, |
78 | LLVM::ModuleTranslation &moduleTranslation) const final { |
79 | auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect()); |
80 | llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); |
81 | if (dialect->getKernelAttrHelper().getName() == attribute.getName()) { |
82 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
83 | if (!func) |
84 | return op->emitOpError(message: Twine(attribute.getName()) + |
85 | " is only supported on `llvm.func` operations" ); |
86 | ; |
87 | |
88 | // For GPU kernels, |
89 | // 1. Insert AMDGPU_KERNEL calling convention. |
90 | // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user |
91 | // has overriden this value - 256 is the default in clang |
92 | llvm::Function *llvmFunc = |
93 | moduleTranslation.lookupFunction(name: func.getName()); |
94 | llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); |
95 | if (!llvmFunc->hasFnAttribute(Kind: "amdgpu-flat-work-group-size" )) { |
96 | llvmFunc->addFnAttr(Kind: "amdgpu-flat-work-group-size" , Val: "1,256" ); |
97 | } |
98 | |
99 | // MLIR's GPU kernel APIs all assume and produce uniformly-sized |
100 | // workgroups, so the lowering of the `rocdl.kernel` marker encodes this |
101 | // assumption. This assumption may be overridden by setting |
102 | // `rocdl.uniform_work_group_size` on a given function. |
103 | if (!llvmFunc->hasFnAttribute(Kind: "uniform-work-group-size" )) |
104 | llvmFunc->addFnAttr(Kind: "uniform-work-group-size" , Val: "true" ); |
105 | } |
106 | // Override flat-work-group-size |
107 | // TODO: update clients to rocdl.flat_work_group_size instead, |
108 | // then remove this half of the branch |
109 | if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() == |
110 | attribute.getName()) { |
111 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
112 | if (!func) |
113 | return op->emitOpError(message: Twine(attribute.getName()) + |
114 | " is only supported on `llvm.func` operations" ); |
115 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
116 | if (!value) |
117 | return op->emitOpError(message: Twine(attribute.getName()) + |
118 | " must be an integer" ); |
119 | |
120 | llvm::Function *llvmFunc = |
121 | moduleTranslation.lookupFunction(name: func.getName()); |
122 | llvm::SmallString<8> llvmAttrValue; |
123 | llvm::raw_svector_ostream attrValueStream(llvmAttrValue); |
124 | attrValueStream << "1," << value.getInt(); |
125 | llvmFunc->addFnAttr("amdgpu-flat-work-group-size" , llvmAttrValue); |
126 | } |
127 | if (dialect->getWavesPerEuAttrHelper().getName() == attribute.getName()) { |
128 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
129 | if (!func) |
130 | return op->emitOpError(message: Twine(attribute.getName()) + |
131 | " is only supported on `llvm.func` operations" ); |
132 | auto value = dyn_cast<IntegerAttr>(attribute.getValue()); |
133 | if (!value) |
134 | return op->emitOpError(message: Twine(attribute.getName()) + |
135 | " must be an integer" ); |
136 | |
137 | llvm::Function *llvmFunc = |
138 | moduleTranslation.lookupFunction(name: func.getName()); |
139 | llvm::SmallString<8> llvmAttrValue; |
140 | llvm::raw_svector_ostream attrValueStream(llvmAttrValue); |
141 | attrValueStream << value.getInt(); |
142 | llvmFunc->addFnAttr("amdgpu-waves-per-eu" , llvmAttrValue); |
143 | } |
144 | if (dialect->getFlatWorkGroupSizeAttrHelper().getName() == |
145 | attribute.getName()) { |
146 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
147 | if (!func) |
148 | return op->emitOpError(message: Twine(attribute.getName()) + |
149 | " is only supported on `llvm.func` operations" ); |
150 | auto value = dyn_cast<StringAttr>(attribute.getValue()); |
151 | if (!value) |
152 | return op->emitOpError(message: Twine(attribute.getName()) + |
153 | " must be a string" ); |
154 | |
155 | llvm::Function *llvmFunc = |
156 | moduleTranslation.lookupFunction(name: func.getName()); |
157 | llvm::SmallString<8> llvmAttrValue; |
158 | llvmAttrValue.append(value.getValue()); |
159 | llvmFunc->addFnAttr("amdgpu-flat-work-group-size" , llvmAttrValue); |
160 | } |
161 | if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() == |
162 | attribute.getName()) { |
163 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
164 | if (!func) |
165 | return op->emitOpError(message: Twine(attribute.getName()) + |
166 | " is only supported on `llvm.func` operations" ); |
167 | auto value = dyn_cast<BoolAttr>(attribute.getValue()); |
168 | if (!value) |
169 | return op->emitOpError(message: Twine(attribute.getName()) + |
170 | " must be a boolean" ); |
171 | llvm::Function *llvmFunc = |
172 | moduleTranslation.lookupFunction(name: func.getName()); |
173 | llvmFunc->addFnAttr("uniform-work-group-size" , |
174 | value.getValue() ? "true" : "false" ); |
175 | } |
176 | if (dialect->getUnsafeFpAtomicsAttrHelper().getName() == |
177 | attribute.getName()) { |
178 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
179 | if (!func) |
180 | return op->emitOpError(message: Twine(attribute.getName()) + |
181 | " is only supported on `llvm.func` operations" ); |
182 | auto value = dyn_cast<BoolAttr>(attribute.getValue()); |
183 | if (!value) |
184 | return op->emitOpError(message: Twine(attribute.getName()) + |
185 | " must be a boolean" ); |
186 | llvm::Function *llvmFunc = |
187 | moduleTranslation.lookupFunction(name: func.getName()); |
188 | llvmFunc->addFnAttr("amdgpu-unsafe-fp-atomics" , |
189 | value.getValue() ? "true" : "false" ); |
190 | } |
191 | // Set reqd_work_group_size metadata |
192 | if (dialect->getReqdWorkGroupSizeAttrHelper().getName() == |
193 | attribute.getName()) { |
194 | auto func = dyn_cast<LLVM::LLVMFuncOp>(op); |
195 | if (!func) |
196 | return op->emitOpError(message: Twine(attribute.getName()) + |
197 | " is only supported on `llvm.func` operations" ); |
198 | auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue()); |
199 | if (!value) |
200 | return op->emitOpError(message: Twine(attribute.getName()) + |
201 | " must be a dense i32 array attribute" ); |
202 | SmallVector<llvm::Metadata *, 3> metadata; |
203 | llvm::Type *i32 = llvm::IntegerType::get(C&: llvmContext, NumBits: 32); |
204 | for (int32_t i : value.asArrayRef()) { |
205 | llvm::Constant *constant = llvm::ConstantInt::get(i32, i); |
206 | metadata.push_back(llvm::ConstantAsMetadata::get(constant)); |
207 | } |
208 | llvm::Function *llvmFunc = |
209 | moduleTranslation.lookupFunction(name: func.getName()); |
210 | llvm::MDNode *node = llvm::MDNode::get(Context&: llvmContext, MDs: metadata); |
211 | llvmFunc->setMetadata(Kind: "reqd_work_group_size" , Node: node); |
212 | } |
213 | |
214 | // Atomic and nontemporal metadata |
215 | if (dialect->getLastUseAttrHelper().getName() == attribute.getName()) { |
216 | for (llvm::Instruction *i : instructions) |
217 | i->setMetadata("amdgpu.last.use" , llvm::MDNode::get(llvmContext, {})); |
218 | } |
219 | if (dialect->getNoRemoteMemoryAttrHelper().getName() == |
220 | attribute.getName()) { |
221 | for (llvm::Instruction *i : instructions) |
222 | i->setMetadata("amdgpu.no.remote.memory" , |
223 | llvm::MDNode::get(llvmContext, {})); |
224 | } |
225 | if (dialect->getNoFineGrainedMemoryAttrHelper().getName() == |
226 | attribute.getName()) { |
227 | for (llvm::Instruction *i : instructions) |
228 | i->setMetadata("amdgpu.no.fine.grained.memory" , |
229 | llvm::MDNode::get(llvmContext, {})); |
230 | } |
231 | if (dialect->getIgnoreDenormalModeAttrHelper().getName() == |
232 | attribute.getName()) { |
233 | for (llvm::Instruction *i : instructions) |
234 | i->setMetadata("amdgpu.ignore.denormal.mode" , |
235 | llvm::MDNode::get(llvmContext, {})); |
236 | } |
237 | |
238 | return success(); |
239 | } |
240 | }; |
241 | } // namespace |
242 | |
243 | void mlir::registerROCDLDialectTranslation(DialectRegistry ®istry) { |
244 | registry.insert<ROCDL::ROCDLDialect>(); |
245 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) { |
246 | dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>(); |
247 | }); |
248 | } |
249 | |
250 | void mlir::registerROCDLDialectTranslation(MLIRContext &context) { |
251 | DialectRegistry registry; |
252 | registerROCDLDialectTranslation(registry); |
253 | context.appendDialectRegistry(registry); |
254 | } |
255 | |