| 1 | //===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a pass to generate ROCDLIR operations for higher-level |
| 10 | // GPU operations. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" |
| 15 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
| 16 | #include "mlir/Pass/Pass.h" |
| 17 | #include "mlir/Pass/PassManager.h" |
| 18 | |
| 19 | #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" |
| 20 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| 21 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" |
| 22 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
| 23 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| 24 | #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
| 25 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 26 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 27 | #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" |
| 28 | #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" |
| 29 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| 30 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
| 31 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 32 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 33 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 34 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 35 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| 36 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 37 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 38 | #include "mlir/IR/BuiltinAttributes.h" |
| 39 | #include "mlir/Pass/Pass.h" |
| 40 | #include "mlir/Transforms/DialectConversion.h" |
| 41 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 42 | |
| 43 | #include "../GPUCommon/GPUOpsLowering.h" |
| 44 | #include "../GPUCommon/IndexIntrinsicsOpLowering.h" |
| 45 | |
| 46 | namespace mlir { |
| 47 | #define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS |
| 48 | #include "mlir/Conversion/Passes.h.inc" |
| 49 | } // namespace mlir |
| 50 | |
| 51 | using namespace mlir; |
| 52 | |
| 53 | // Truncate or extend the result depending on the index bitwidth specified |
| 54 | // by the LLVMTypeConverter options. |
| 55 | static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, |
| 56 | Location loc, Value value, |
| 57 | const LLVMTypeConverter &converter) { |
| 58 | int64_t intWidth = cast<IntegerType>(Val: value.getType()).getWidth(); |
| 59 | int64_t indexBitwidth = converter.getIndexTypeBitwidth(); |
| 60 | auto indexBitwidthType = |
| 61 | IntegerType::get(context: rewriter.getContext(), width: converter.getIndexTypeBitwidth()); |
| 62 | // TODO: use <=> in C++20. |
| 63 | if (indexBitwidth > intWidth) { |
| 64 | return rewriter.create<LLVM::SExtOp>(location: loc, args&: indexBitwidthType, args&: value); |
| 65 | } |
| 66 | if (indexBitwidth < intWidth) { |
| 67 | return rewriter.create<LLVM::TruncOp>(location: loc, args&: indexBitwidthType, args&: value); |
| 68 | } |
| 69 | return value; |
| 70 | } |
| 71 | |
| 72 | /// Returns true if the given `gpu.func` can be safely called using the bare |
| 73 | /// pointer calling convention. |
| 74 | static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { |
| 75 | bool canBeBare = true; |
| 76 | for (Type type : func.getArgumentTypes()) |
| 77 | if (auto memrefTy = dyn_cast<BaseMemRefType>(Val&: type)) |
| 78 | canBeBare &= LLVMTypeConverter::canConvertToBarePtr(type: memrefTy); |
| 79 | return canBeBare; |
| 80 | } |
| 81 | |
// Emits the wave-relative lane id as an i32 value, computed by chaining the
// mbcnt.lo and mbcnt.hi intrinsics over a full execution mask (-1): mbcnt.lo
// counts set bits among the low 32 lanes below the current one, and mbcnt.hi
// adds the count for the high 32 lanes, so the pair also works for wave64.
//
// NOTE(review): `indexBitwidth` is currently unused — the result is always
// i32 and callers (e.g. the shuffle lowering) use it as-is; confirm whether
// the parameter was meant to drive a trunc/ext of the result.
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                       const unsigned indexBitwidth) {
  auto int32Type = IntegerType::get(context: rewriter.getContext(), width: 32);
  Value zero = rewriter.create<arith::ConstantIntOp>(location: loc, args: 0, args: 32);
  Value minus1 = rewriter.create<arith::ConstantIntOp>(location: loc, args: -1, args: 32);
  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(location: loc, args&: int32Type,
                                                    args: ValueRange{minus1, zero});
  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(location: loc, args&: int32Type,
                                                   args: ValueRange{minus1, mbcntLo});
  return laneId;
}
// Default LLVM data layout string for amdgcn targets. Attached to the GPU
// module by the pass below when no explicit llvm.data_layout attribute is
// already present.
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
    "32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";
| 99 | |
| 100 | namespace { |
| 101 | struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { |
| 102 | using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern; |
| 103 | |
| 104 | LogicalResult |
| 105 | matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, |
| 106 | ConversionPatternRewriter &rewriter) const override { |
| 107 | auto loc = op->getLoc(); |
| 108 | MLIRContext *context = rewriter.getContext(); |
| 109 | // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0) |
| 110 | // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) |
| 111 | |
| 112 | Type intTy = IntegerType::get(context, width: 32); |
| 113 | Value zero = rewriter.create<arith::ConstantIntOp>(location: loc, args: 0, args: 32); |
| 114 | Value minus1 = rewriter.create<arith::ConstantIntOp>(location: loc, args: -1, args: 32); |
| 115 | Value mbcntLo = |
| 116 | rewriter.create<ROCDL::MbcntLoOp>(location: loc, args&: intTy, args: ValueRange{minus1, zero}); |
| 117 | Value laneId = rewriter.create<ROCDL::MbcntHiOp>( |
| 118 | location: loc, args&: intTy, args: ValueRange{minus1, mbcntLo}); |
| 119 | // Truncate or extend the result depending on the index bitwidth specified |
| 120 | // by the LLVMTypeConverter options. |
| 121 | const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); |
| 122 | if (indexBitwidth > 32) { |
| 123 | laneId = rewriter.create<LLVM::SExtOp>( |
| 124 | location: loc, args: IntegerType::get(context, width: indexBitwidth), args&: laneId); |
| 125 | } else if (indexBitwidth < 32) { |
| 126 | laneId = rewriter.create<LLVM::TruncOp>( |
| 127 | location: loc, args: IntegerType::get(context, width: indexBitwidth), args&: laneId); |
| 128 | } |
| 129 | rewriter.replaceOp(op, newValues: {laneId}); |
| 130 | return success(); |
| 131 | } |
| 132 | }; |
| 133 | |
| 134 | struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> { |
| 135 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 136 | |
| 137 | GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter, |
| 138 | amdgpu::Chipset chipset) |
| 139 | : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter), |
| 140 | chipset(chipset) {} |
| 141 | |
| 142 | LogicalResult |
| 143 | matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor, |
| 144 | ConversionPatternRewriter &rewriter) const override { |
| 145 | LLVM::ConstantRangeAttr bounds = nullptr; |
| 146 | bool isBeforeGfx10 = chipset.majorVersion < 10; |
| 147 | if (auto upperBoundAttr = op.getUpperBoundAttr()) { |
| 148 | bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>( |
| 149 | /*bitWidth=*/args: 32, /*lower=*/args: isBeforeGfx10 ? 64 : 32, |
| 150 | /*upper=*/args: op.getUpperBoundAttr().getInt() + 1); |
| 151 | } |
| 152 | Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>( |
| 153 | location: op.getLoc(), args: rewriter.getI32Type(), args&: bounds); |
| 154 | wavefrontOp = truncOrExtToLLVMType(rewriter, loc: op.getLoc(), value: wavefrontOp, |
| 155 | converter: *getTypeConverter()); |
| 156 | rewriter.replaceOp(op, newValues: {wavefrontOp}); |
| 157 | return success(); |
| 158 | } |
| 159 | |
| 160 | const amdgpu::Chipset chipset; |
| 161 | }; |
| 162 | |
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Use the `width` argument to see if src lane is participating.
  /// If not the dstLane would be itself.
  ///
  /// Shuffle with DS Bpermute:
  /// let shflMode = [xor, up, down, idx]
  /// let width = 32(usually warpsize), step = [1, 2, 4, 8, 16, ... , width].
  /// 1. curLaneId = using mbcnt.lo + mbcnt.hi
  /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
  /// 3. dstLane = shflMode(curLaneId, step)
  /// 4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
  /// 7. bpermute(dwordAlignedDstLane, shfl_value).
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value initShflValue = adaptor.getValue();

    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(context: rewriter.getContext(), width: 32);
    Value width = adaptor.getWidth();
    // Step 2: compute (laneId + width) & -width, which is the first lane id
    // past the current lane's participating group (or 0 when the lane is
    // outside any group).
    Value zero = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int32Type, args: 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(location: loc, args&: int32Type, args&: zero, args&: width);
    Value add = rewriter.create<LLVM::AddOp>(location: loc, args&: int32Type, args&: srcLaneId, args&: width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(location: loc, args&: int32Type, args&: add, args&: negwidth);
    Value dstLane;

    // Step 3: derive the source lane from the shuffle mode and offset.
    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = rewriter.create<LLVM::SubOp>(location: loc, args&: int32Type, args&: srcLaneId,
                                             args: adaptor.getOffset());
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = rewriter.create<LLVM::AddOp>(location: loc, args&: int32Type, args&: srcLaneId,
                                             args: adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(location: loc, args&: int32Type, args&: srcLaneId,
                                             args: adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    }
    // Steps 4-5: an out-of-group source lane falls back to the current lane,
    // so inactive lanes read their own value.
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        location: loc, args: LLVM::ICmpPredicate::slt, args&: dstLane, args&: widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(location: loc, args&: isActiveSrcLane,
                                                          args&: dstLane, args&: srcLaneId);
    // Step 6: ds_bpermute addresses lanes in bytes, so scale the lane id by 4.
    Value two = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int32Type, args: 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(location: loc, args&: int32Type, args&: selectDstLane, args&: two);

    // Step 7: bpermute operates on i32 values, so decompose the shuffled
    // value into i32 pieces, permute each, and reassemble the original type.
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(builder&: rewriter, loc, src: initShflValue, dstType: int32Type);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res = rewriter.create<ROCDL::DsBpermuteOp>(location: loc, args&: int32Type,
                                                       args&: dwordAlignedDstLane, args&: v);
      swizzled.emplace_back(Args&: res);
    }
    Value shflValue =
        LLVM::composeValue(builder&: rewriter, loc, src: swizzled, dstType: initShflValue.getType());
    rewriter.replaceOp(op, newValues: {shflValue, isActiveSrcLane});
    return success();
  }
};
| 239 | |
| 240 | /// Import the GPU Ops to ROCDL Patterns. |
| 241 | #include "GPUToROCDL.cpp.inc" |
| 242 | |
| 243 | // A pass that replaces all occurrences of GPU device operations with their |
| 244 | // corresponding ROCDL equivalent. |
| 245 | // |
| 246 | // This pass only handles device code and is not meant to be run on GPU host |
| 247 | // code. |
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  // Programmatic construction: each argument only takes effect if the
  // corresponding pass option was not already set on the command line.
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    // Dialects lowered through the ConvertToLLVM interface are discovered
    // dynamically, so their dependencies must be registered up front.
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    // Attach the default amdgcn data layout unless the module already
    // carries an explicit llvm.data_layout attribute.
    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        name: LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(context: ctx, bytes: amdgcnDataLayout);
      m->setAttr(name: LLVM::LLVMDialect::getDataLayoutAttrName(), value: llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    value: UnitAttr::get(context: ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(name: chipset);
    if (failed(Result: maybeChipset)) {
      emitError(loc: UnknownLoc::get(context: ctx), message: "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    /// Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(Val: m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(bitwidth: indexBitwidth);

    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      // Bare pointers are only valid when every gpu.func argument can be
      // converted; reject the module otherwise.
      WalkResult canUseBarePointers =
          m.walk(callback: [](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(loc: UnknownLoc::get(context: ctx),
                  message: "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map" );
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
      (void)applyPatternsGreedily(op: m, patterns: std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    // Map the GPU memory space enum to the amdgcn numeric address spaces.
    populateGpuMemorySpaceAttributeConversions(
        typeConverter&: converter, mapping: [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value" );
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);
    LLVMConversionTarget target(getContext());

    // Collect ConvertToLLVM patterns from every loaded dialect, restricted
    // to `allowedDialects` when that option is non-empty.
    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : ctx->getLoadedDialects()) {
      bool allowed = allowedDialectsSet.contains(V: dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
      if (!iface) {
        // Error out if dialect was explicily specified but doesn't implement
        // conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, typeConverter&: converter,
                                                     patterns&: llvmPatterns);
    }

    populateAMDGPUToROCDLConversionPatterns(converter, patterns&: llvmPatterns,
                                            chipset: *maybeChipset);
    populateGpuToROCDLConversionPatterns(converter, patterns&: llvmPatterns, runtime,
                                         chipset: *maybeChipset);
    configureGpuToROCDLConversionLegality(target);
    // NOTE(review): execution deliberately continues after signalPassFailure()
    // here, so the attribute rewrite below still runs on a partially
    // converted module — confirm this is intended.
    if (failed(Result: applyPartialConversion(op: m, target, patterns: std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk(callback: [&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(context: ctx, bytes: Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, val: flatSizeAttr);
      }
    });
  }
};
| 397 | |
| 398 | } // namespace |
| 399 | |
| 400 | void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { |
| 401 | target.addIllegalOp<func::FuncOp>(); |
| 402 | target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); |
| 403 | target.addLegalDialect<ROCDL::ROCDLDialect>(); |
| 404 | target.addIllegalDialect<gpu::GPUDialect>(); |
| 405 | target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp, |
| 406 | LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, |
| 407 | LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>(); |
| 408 | // These ops are legal for f32 type. |
| 409 | target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>(callback: [](Operation *op) { |
| 410 | return any_of(Range: op->getOperandTypes(), P: llvm::IsaPred<Float32Type>); |
| 411 | }); |
| 412 | // TODO: Remove once we support replacing non-root ops. |
| 413 | target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>(); |
| 414 | } |
| 415 | |
// Registers all patterns that lower GPU dialect ops to their ROCDL
// equivalents: index/dimension intrinsics, function/return lowering,
// printf (runtime-dependent), shared memory, shuffle, lane id, subgroup
// size, and the math-to-ROCDL device-library patterns.
void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  populateWithGenerated(patterns);
  // Index intrinsics: thread/block ids and block/grid dimensions map onto
  // the per-axis ROCDL ops.
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      arg: converter, args: IndexKind::Block, args: IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      arg: converter, args: IndexKind::Grid, args: IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      arg: converter, args: IndexKind::Block, args: IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      arg: converter, args: IndexKind::Grid, args: IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(arg: converter);
  // gpu.func lowering: private allocas in address space 5, workgroup memory
  // in address space 3, kernels tagged with the ROCDL kernel attribute.
  patterns.add<GPUFuncOpLowering>(
      arg: converter,
      args: GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          .kernelAttributeName: rocdlDialect->getKernelAttrHelper().getName(),
          .kernelBlockSizeAttributeName: rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  // gpu.printf lowering depends on the target runtime's printf mechanism.
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(arg: converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf()
    patterns.add<GPUPrintfOpToLLVMCallLowering>(arg: converter, /*addressSpace=*/args: 4);
  }
  // TODO: Add alignment for workgroup memory
  patterns.add<GPUDynamicSharedMemoryOpLowering>(arg: converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(arg: converter);
  patterns.add<GPUSubgroupSizeOpToROCDL>(arg: converter, args&: chipset);

  populateMathToROCDLConversionPatterns(converter, patterns);
}
| 461 | |
| 462 | std::unique_ptr<OperationPass<gpu::GPUModuleOp>> |
| 463 | mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset, |
| 464 | unsigned indexBitwidth, |
| 465 | bool useBarePtrCallConv, |
| 466 | gpu::amd::Runtime runtime) { |
| 467 | return std::make_unique<LowerGpuOpsToROCDLOpsPass>( |
| 468 | args: chipset, args&: indexBitwidth, args&: useBarePtrCallConv, args&: runtime); |
| 469 | } |
| 470 | |