1//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns to try to promote `gpu.shuffle`s to specialized
10// AMDGPU intrinsics.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/GPU/Transforms/Passes.h"
15
16#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/IR/PatternMatch.h"
20
21using namespace mlir;
22
23namespace {
24/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
25/// and offset must be a constant integer in the range [0, 31].
26struct PromoteShuffleToSwizzlePattern
27 : public OpRewritePattern<gpu::ShuffleOp> {
28 using OpRewritePattern::OpRewritePattern;
29
30 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
31 PatternRewriter &rewriter) const override {
32 if (op.getMode() != gpu::ShuffleMode::XOR)
33 return rewriter.notifyMatchFailure(op,
34 "only xor shuffle mode is supported");
35
36 if (!isConstantIntValue(op.getWidth(), 64))
37 return rewriter.notifyMatchFailure(op,
38 "only 64 width shuffle is supported");
39
40 std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
41 if (!offset)
42 return rewriter.notifyMatchFailure(op,
43 "offset must be a constant integer");
44
45 int64_t offsetValue = *offset;
46 if (offsetValue < 0 || offsetValue >= 32)
47 return rewriter.notifyMatchFailure(op,
48 "offset must be in the range [0, 31]");
49
50 Location loc = op.getLoc();
51 Value res = rewriter.create<amdgpu::SwizzleBitModeOp>(
52 loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31,
53 /*orMask=*/0, /*xorMask=*/offsetValue);
54 Value valid = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, /*width*/ args: 1);
55 rewriter.replaceOp(op, {res, valid});
56 return success();
57 }
58};
59} // namespace
60
61void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
62 RewritePatternSet &patterns) {
63 patterns.add<PromoteShuffleToSwizzlePattern>(arg: patterns.getContext());
64}
65

// Source: mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp