1//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate ROCDLIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
15#include "mlir/Dialect/Arith/Transforms/Passes.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Pass/PassManager.h"
18
19#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
20#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
21#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
22#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
23#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
24#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
25#include "mlir/Conversion/LLVMCommon/Pattern.h"
26#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
27#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
28#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
29#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
30#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
31#include "mlir/Dialect/Func/IR/FuncOps.h"
32#include "mlir/Dialect/GPU/IR/GPUDialect.h"
33#include "mlir/Dialect/GPU/Transforms/Passes.h"
34#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
35#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
36#include "mlir/Dialect/MemRef/IR/MemRef.h"
37#include "mlir/Dialect/Vector/IR/VectorOps.h"
38#include "mlir/IR/BuiltinAttributes.h"
39#include "mlir/Pass/Pass.h"
40#include "mlir/Transforms/DialectConversion.h"
41#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
42
43#include "../GPUCommon/GPUOpsLowering.h"
44#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
45
46namespace mlir {
47#define GEN_PASS_DEF_CONVERTGPUOPSTOROCDLOPS
48#include "mlir/Conversion/Passes.h.inc"
49} // namespace mlir
50
51using namespace mlir;
52
53// Truncate or extend the result depending on the index bitwidth specified
54// by the LLVMTypeConverter options.
55static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
56 Location loc, Value value,
57 const LLVMTypeConverter &converter) {
58 int64_t intWidth = cast<IntegerType>(Val: value.getType()).getWidth();
59 int64_t indexBitwidth = converter.getIndexTypeBitwidth();
60 auto indexBitwidthType =
61 IntegerType::get(context: rewriter.getContext(), width: converter.getIndexTypeBitwidth());
62 // TODO: use <=> in C++20.
63 if (indexBitwidth > intWidth) {
64 return rewriter.create<LLVM::SExtOp>(location: loc, args&: indexBitwidthType, args&: value);
65 }
66 if (indexBitwidth < intWidth) {
67 return rewriter.create<LLVM::TruncOp>(location: loc, args&: indexBitwidthType, args&: value);
68 }
69 return value;
70}
71
72/// Returns true if the given `gpu.func` can be safely called using the bare
73/// pointer calling convention.
74static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
75 bool canBeBare = true;
76 for (Type type : func.getArgumentTypes())
77 if (auto memrefTy = dyn_cast<BaseMemRefType>(Val&: type))
78 canBeBare &= LLVMTypeConverter::canConvertToBarePtr(type: memrefTy);
79 return canBeBare;
80}
81
82static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
83 const unsigned indexBitwidth) {
84 auto int32Type = IntegerType::get(context: rewriter.getContext(), width: 32);
85 Value zero = rewriter.create<arith::ConstantIntOp>(location: loc, args: 0, args: 32);
86 Value minus1 = rewriter.create<arith::ConstantIntOp>(location: loc, args: -1, args: 32);
87 Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(location: loc, args&: int32Type,
88 args: ValueRange{minus1, zero});
89 Value laneId = rewriter.create<ROCDL::MbcntHiOp>(location: loc, args&: int32Type,
90 args: ValueRange{minus1, mbcntLo});
91 return laneId;
92}
// Default LLVM data layout string for the amdgcn target. Applied to GPU
// modules that do not already carry an explicit llvm.data_layout attribute
// (see runOnOperation below).
static constexpr StringLiteral amdgcnDataLayout =
    "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
    "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
    "32-v32:"
    "32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:"
    "64-S32-A5-G1-ni:7:8:9";
99
100namespace {
101struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
102 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
103
104 LogicalResult
105 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
106 ConversionPatternRewriter &rewriter) const override {
107 auto loc = op->getLoc();
108 MLIRContext *context = rewriter.getContext();
109 // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
110 // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
111
112 Type intTy = IntegerType::get(context, width: 32);
113 Value zero = rewriter.create<arith::ConstantIntOp>(location: loc, args: 0, args: 32);
114 Value minus1 = rewriter.create<arith::ConstantIntOp>(location: loc, args: -1, args: 32);
115 Value mbcntLo =
116 rewriter.create<ROCDL::MbcntLoOp>(location: loc, args&: intTy, args: ValueRange{minus1, zero});
117 Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
118 location: loc, args&: intTy, args: ValueRange{minus1, mbcntLo});
119 // Truncate or extend the result depending on the index bitwidth specified
120 // by the LLVMTypeConverter options.
121 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
122 if (indexBitwidth > 32) {
123 laneId = rewriter.create<LLVM::SExtOp>(
124 location: loc, args: IntegerType::get(context, width: indexBitwidth), args&: laneId);
125 } else if (indexBitwidth < 32) {
126 laneId = rewriter.create<LLVM::TruncOp>(
127 location: loc, args: IntegerType::get(context, width: indexBitwidth), args&: laneId);
128 }
129 rewriter.replaceOp(op, newValues: {laneId});
130 return success();
131 }
132};
133
134struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
135 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
136
137 GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
138 amdgpu::Chipset chipset)
139 : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
140 chipset(chipset) {}
141
142 LogicalResult
143 matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
144 ConversionPatternRewriter &rewriter) const override {
145 LLVM::ConstantRangeAttr bounds = nullptr;
146 bool isBeforeGfx10 = chipset.majorVersion < 10;
147 if (auto upperBoundAttr = op.getUpperBoundAttr()) {
148 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
149 /*bitWidth=*/args: 32, /*lower=*/args: isBeforeGfx10 ? 64 : 32,
150 /*upper=*/args: op.getUpperBoundAttr().getInt() + 1);
151 }
152 Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
153 location: op.getLoc(), args: rewriter.getI32Type(), args&: bounds);
154 wavefrontOp = truncOrExtToLLVMType(rewriter, loc: op.getLoc(), value: wavefrontOp,
155 converter: *getTypeConverter());
156 rewriter.replaceOp(op, newValues: {wavefrontOp});
157 return success();
158 }
159
160 const amdgpu::Chipset chipset;
161};
162
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding ROCDL ops.
  ///
  /// Use the `width` argument to see if src lane is participating.
  /// If not the dstLane would be itself.
  ///
  /// Shuffle with DS Bpermute:
  /// let shflMode = [xor, up, down, idx]
  /// let width = 32(usually warpsize), step = [1, 2, 4, 8, 16, ... , width].
  /// 1. curLaneId = using mbcnt.lo + mbcnt.hi
  /// 2. widthOrZeroIfOutside = (curLaneId + width) & -width
  /// 3. dstLane = shflMode(curLaneId, step)
  /// 4. isActiveSrcLane = dstLane < widthOrZeroIfOutside
  /// 5. dstLane = isActiveSrcLane ? dstLane : curLaneId
  /// 6. dwordAlignedDstLane = dstLane * 4 or dstLane << 2.
  /// 7. bpermute(dwordAlignedDstLane, shfl_value).
  ///
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value initShflValue = adaptor.getValue();

    // Step 1: current lane id via mbcnt.lo + mbcnt.hi.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value width = adaptor.getWidth();
    // Step 2: (curLaneId + width) & -width — the group's upper lane bound, or
    // zero if the current lane is outside the participating group.
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
    Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
    Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
    Value widthOrZeroIfOutside =
        rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
    Value dstLane;

    // Step 3: compute the candidate source lane for each shuffle mode.
    switch (op.getMode()) {
    case gpu::ShuffleMode::UP:
      dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::DOWN:
      dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::XOR:
      dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
                                             adaptor.getOffset());
      break;
    case gpu::ShuffleMode::IDX:
      dstLane = adaptor.getOffset();
      break;
    }
    // Steps 4-5: if the source lane is not participating, read from the
    // current lane instead.
    Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
    Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
                                                          dstLane, srcLaneId);
    // Step 6: ds_bpermute addresses lanes in byte units, hence << 2.
    Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
    Value dwordAlignedDstLane =
        rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);

    // Step 7: bpermute each 32-bit piece of the (possibly wider) value and
    // reassemble the result in the original type.
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
                                                       dwordAlignedDstLane, v);
      swizzled.emplace_back(res);
    }
    Value shflValue =
        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    return success();
  }
};
239
240/// Import the GPU Ops to ROCDL Patterns.
241#include "GPUToROCDL.cpp.inc"
242
// A pass that replaces all occurrences of GPU device operations with their
// corresponding ROCDL equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
  LowerGpuOpsToROCDLOpsPass() = default;
  // Constructor arguments only take effect when the corresponding pass option
  // was not already set explicitly (e.g. on the command line).
  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
                            bool useBarePtrCallConv,
                            gpu::amd::Runtime runtime) {
    if (this->chipset.getNumOccurrences() == 0)
      this->chipset = chipset;
    if (this->indexBitwidth.getNumOccurrences() == 0)
      this->indexBitwidth = indexBitwidth;
    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
      this->useBarePtrCallConv = useBarePtrCallConv;
    if (this->runtime.getNumOccurrences() == 0)
      this->runtime = runtime;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    // Dialects contributing ConvertToLLVM interfaces are loaded lazily below;
    // register their dependent dialects up front.
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();
    MLIRContext *ctx = m.getContext();

    // Attach the default amdgcn data layout unless the module already has one.
    auto llvmDataLayout = m->getAttrOfType<StringAttr>(
        LLVM::LLVMDialect::getDataLayoutAttrName());
    if (!llvmDataLayout) {
      llvmDataLayout = StringAttr::get(ctx, amdgcnDataLayout);
      m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), llvmDataLayout);
    }
    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(ctx));
    }

    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    /// Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    options.dataLayout = llvm::DataLayout(llvmDataLayout.getValue());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    // Bare-pointer calling convention is only valid if every gpu.func in the
    // module accepts memrefs that convert to bare pointers.
    if (useBarePtrCallConv) {
      options.useBarePtrCallConv = true;
      WalkResult canUseBarePointers =
          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
            if (canBeCalledWithBarePointers(func))
              return WalkResult::advance();
            return WalkResult::interrupt();
          });
      if (canUseBarePointers.wasInterrupted()) {
        emitError(UnknownLoc::get(ctx),
                  "bare pointer calling convention requires all memrefs to "
                  "have static shape and use the identity map");
        return signalPassFailure();
      }
    }

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(ctx);
      populateGpuRewritePatterns(patterns);
      populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
      // Best-effort: failure to converge here is not fatal.
      (void)applyPatternsGreedily(m, std::move(patterns));
    }

    LLVMTypeConverter converter(ctx, options);
    // Map GPU memory spaces to the amdgcn address-space numbering.
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) {
          switch (space) {
          case gpu::AddressSpace::Global:
            return 1;
          case gpu::AddressSpace::Workgroup:
            return 3;
          case gpu::AddressSpace::Private:
            return 5;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet llvmPatterns(ctx);
    LLVMConversionTarget target(getContext());

    // Collect conversion patterns from every loaded dialect that implements
    // ConvertToLLVMPatternInterface, optionally filtered by the
    // `allowedDialects` pass option.
    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : ctx->getLoadedDialects()) {
      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if dialect was explicily specified but doesn't implement
        // conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                            *maybeChipset);
    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
                                         *maybeChipset);
    configureGpuToROCDLConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
    auto reqdWorkGroupSizeAttrHelper =
        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
    auto flatWorkGroupSizeAttrHelper =
        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
    // Manually rewrite known block size attributes so the LLVMIR translation
    // infrastructure can pick them up.
    m.walk([&](LLVM::LLVMFuncOp op) {
      if (reqdWorkGroupSizeAttrHelper.isAttrPresent(op)) {
        auto blockSizes = reqdWorkGroupSizeAttrHelper.getAttr(op);
        // Also set up the rocdl.flat_work_group_size attribute to prevent
        // conflicting metadata.
        uint32_t flatSize = 1;
        for (uint32_t size : blockSizes.asArrayRef()) {
          flatSize *= size;
        }
        StringAttr flatSizeAttr =
            StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
      }
    });
  }
};
397
398} // namespace
399
400void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
401 target.addIllegalOp<func::FuncOp>();
402 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
403 target.addLegalDialect<ROCDL::ROCDLDialect>();
404 target.addIllegalDialect<gpu::GPUDialect>();
405 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
406 LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
407 LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
408 // These ops are legal for f32 type.
409 target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>(callback: [](Operation *op) {
410 return any_of(Range: op->getOperandTypes(), P: llvm::IsaPred<Float32Type>);
411 });
412 // TODO: Remove once we support replacing non-root ops.
413 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
414}
415
/// Registers all GPU-to-ROCDL conversion patterns: index intrinsics, function
/// lowering, printf, dynamic shared memory, shuffle, lane id, subgroup size,
/// and the math-to-ROCDL patterns.
void mlir::populateGpuToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;
  using mlir::gpu::amd::Runtime;
  auto *rocdlDialect =
      converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
  // TableGen'd patterns from GPUToROCDL.cpp.inc.
  populateWithGenerated(patterns);
  // Index intrinsics: thread/block ids and block/grid dimensions.
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                      ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, ROCDL::BlockIdXOp, ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, ROCDL::BlockDimXOp,
                                      ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, ROCDL::GridDimXOp, ROCDL::GridDimYOp, ROCDL::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim);
  patterns.add<GPUReturnOpLowering>(converter);
  // gpu.func lowering: ROCDL address spaces plus kernel/block-size attribute
  // names understood by the ROCDL dialect.
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
          /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
          rocdlDialect->getKernelAttrHelper().getName(),
          rocdlDialect->getReqdWorkGroupSizeAttrHelper().getName()});
  // gpu.printf lowering depends on the target runtime's printf mechanism.
  if (Runtime::HIP == runtime) {
    patterns.add<GPUPrintfOpToHIPLowering>(converter);
  } else if (Runtime::OpenCL == runtime) {
    // Use address space = 4 to match the OpenCL definition of printf()
    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
  }
  // TODO: Add alignment for workgroup memory
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
  patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

  populateMathToROCDLConversionPatterns(converter, patterns);
}
461
462std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
463mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
464 unsigned indexBitwidth,
465 bool useBarePtrCallConv,
466 gpu::amd::Runtime runtime) {
467 return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
468 args: chipset, args&: indexBitwidth, args&: useBarePtrCallConv, args&: runtime);
469}
470

source code of mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp