//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
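/// For example (illustrative argument types), a `memref<16x16xf32>` argument
/// (static shape, identity layout) can be passed as a bare pointer, while a
/// `memref<?xf32>` argument cannot.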
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
  bool canBeBare = true;
  for (Type type : func.getArgumentTypes())
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
  return canBeBare;
}

/// Computes the lane id of the current thread by chaining the
/// `llvm.amdgcn.mbcnt.lo` and `llvm.amdgcn.mbcnt.hi` intrinsics: each counts
/// the set bits of the mask at positions below the current lane, so with an
/// all-ones mask the result is exactly the lane index. `indexBitwidth` is
/// currently unused.
Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
                                                    ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
                                                   ValueRange{minus1, mbcntLo});
  return laneId;
}

/// The AMDGCN target data layout string, applied as the default when the
/// module does not already carry an explicit LLVM data layout attribute.
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:"
    "128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-"
    "G1-ni:7:8";

namespace {
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    // Convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)

    Type intTy = IntegerType::get(context, 32);
    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
    Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
    Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(
        loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      laneId = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    } else if (indexBitwidth < 32) {
      laneId = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), laneId);
    }
    rewriter.replaceOp(op, {laneId});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Uses the `width` argument to check whether the source lane participates;
  /// if it does not, the lane reads back its own value.
  ///
  /// Shuffle with DS Bpermute:
  /// let shflMode = [xor, up, down, idx]
  /// let width = 32 (usually the warp size), step = [1, 2, 4, 8, 16, ..., width]
  /// 1. curLaneId = mbcnt.hi(-1, mbcnt.lo(-1, 0))
  /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
  /// 3. dstLane = shflMode(curLaneId, step)
  /// 4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  /// 6. dwordAlignedDstLane = dstLane * 4, i.e. dstLane << 2
  /// 7. bpermute(dwordAlignedDstLane, shfl_value)
  ///
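  /// An illustrative walk-through (example values, not from the source):
  /// with mode = xor, offset = 1, width = 64, executing on lane 5 of a
  /// 64-wide wavefront:
  ///   widthOrZeroIfOutside = (5 + 64) & -64 = 64
  ///   dstLane              = 5 ^ 1          = 4
  ///   isActiveSrcLane      = 4 < 64         = true
  ///   dwordAlignedDstLane  = 4 << 2         = 16
  /// so ds_bpermute reads the shuffled value from lane 4.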
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // TODO: Add support for non 32-bit shuffle values.
    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
      return failure();
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;
    // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
    // perf.
    switch (op.getMode()) {
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    default:
      return failure();
    }
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
    Value initShflValue = adaptor.getValue();
    if (adaptor.getValue().getType().isF32()) {
      initShflValue =
          rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
    }
    Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
        loc, int32Type, dwordAlignedDstLane, initShflValue);
    if (adaptor.getValue().getType().isF32()) {
      shflValue = rewriter.create<LLVM::BitcastOp>(
          loc, adaptor.getValue().getType(), shflValue);
    }
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
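//
// A command-line usage sketch (option names assumed to follow the tablegen'd
// pass definition):
//   mlir-opt --convert-gpu-to-rocdl='chipset=gfx908' kernel.mlir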
struct LowerGpuOpsToROCDLOpsPass
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      arith::populateExpandBFloat16Patterns(patterns);
      (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
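    // Map the GPU address spaces onto the AMDGPU backend's numeric address
    // spaces: 1 = global, 3 = workgroup (LDS), 5 = private (scratch).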
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);

    mlir::arith::populateArithToLLVMConversionPatterns(converter,
                                                       llvmPatterns);
    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
    LLVMConversionTarget target(getContext());
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
              op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
        reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};

} // namespace

void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<ROCDL::ROCDLDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
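  // Mark the LLVM math intrinsic wrappers illegal: on this path, math
  // operations are expected to become ROCm device-library (`__ocml_*`) calls
  // (see populateGpuToROCDLConversionPatterns) rather than LLVM intrinsics.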
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}

/// Registers conversions of `OpTy` to a runtime library call: vector operands
/// are scalarized first, then the scalar op is rewritten into a call to
/// `f32Func` or `f64Func` depending on the element type.
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateGpuToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime) {
  using mlir::gpu::amd::Runtime;

  populateWithGenerated(patterns);
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                       ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
          converter, gpu::GPUFuncOp::getKnownBlockSizeAttrName());
  patterns.add<GPUIndexIntrinsicOpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, gpu::GPUFuncOp::getKnownGridSizeAttrName());
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                       ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>,
           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
                                       ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
           GPUReturnOpLowering>(converter);
  patterns.add<GPUFuncOpLowering>(
      converter,
      /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
      /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
      ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf().
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory.
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
                                   "__ocml_fabs_f64");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
                                   "__ocml_atan_f64");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
                                    "__ocml_atan2_f64");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
                                   "__ocml_cbrt_f64");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
                                   "__ocml_ceil_f64");
  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
                                  "__ocml_cos_f64");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
                                  "__ocml_exp_f64");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                   "__ocml_exp2_f64");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                    "__ocml_expm1_f64");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                    "__ocml_floor_f64");
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
                                    "__ocml_fmod_f64");
  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
                                  "__ocml_log_f64");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
                                    "__ocml_log10_f64");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
                                    "__ocml_log1p_f64");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
                                   "__ocml_log2_f64");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
                                   "__ocml_pow_f64");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
                                    "__ocml_rsqrt_f64");
  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                  "__ocml_sin_f64");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
                                   "__ocml_sqrt_f64");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
                                   "__ocml_tanh_f64");
  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
                                  "__ocml_tan_f64");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                  "__ocml_erf_f64");
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                      unsigned indexBitwidth,
                                      bool useBarePtrCallConv,
                                      gpu::amd::Runtime runtime) {
  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
      chipset, indexBitwidth, useBarePtrCallConv, runtime);
}
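
// A minimal C++ usage sketch (illustrative; assumes the declaration in
// GPUToROCDLPass.h provides defaults for the remaining options, and that
// `ctx` and `module` already exist):
//   PassManager pm(ctx);
//   pm.addNestedPass<gpu::GPUModuleOp>(
//       createLowerGpuOpsToROCDLOpsPass(/*chipset=*/"gfx908"));
//   if (failed(pm.run(module)))
//     llvm::errs() << "GPU-to-ROCDL lowering failed\n";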