1//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements in-dialect rewriting of the shuffle op for i64 and f64
// values, splitting each 64-bit shuffle into two 32-bit shuffles. The
// shift-and-truncate sequence used here matches the IR that clang emits for
// such shuffle operations at `-O3`.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/GPU/Transforms/Passes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/PatternMatch.h"
21
22using namespace mlir;
23
24namespace {
25struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
26 using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
27
28 void initialize() {
29 // Required as the pattern will replace the Op with 2 additional ShuffleOps.
30 setHasBoundedRewriteRecursion();
31 }
32 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
33 PatternRewriter &rewriter) const override {
34 auto loc = op.getLoc();
35 auto value = op.getValue();
36 auto valueType = value.getType();
37 auto valueLoc = value.getLoc();
38 auto i32 = rewriter.getI32Type();
39 auto i64 = rewriter.getI64Type();
40
41 // If the type of the value is either i32 or f32, the op is already valid.
42 if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
43 return rewriter.notifyMatchFailure(
44 arg&: op, msg: "only 64-bit int/float types are supported");
45
46 Value lo, hi;
47
48 // Float types must be converted to i64 to extract the bits.
49 if (isa<FloatType>(Val: valueType))
50 value = rewriter.create<arith::BitcastOp>(location: valueLoc, args&: i64, args&: value);
51
52 // Get the low bits by trunc(value).
53 lo = rewriter.create<arith::TruncIOp>(location: valueLoc, args&: i32, args&: value);
54
55 // Get the high bits by trunc(value >> 32).
56 auto c32 = rewriter.create<arith::ConstantOp>(
57 location: valueLoc, args: rewriter.getIntegerAttr(type: i64, value: 32));
58 hi = rewriter.create<arith::ShRUIOp>(location: valueLoc, args&: value, args&: c32);
59 hi = rewriter.create<arith::TruncIOp>(location: valueLoc, args&: i32, args&: hi);
60
61 // Shuffle the values.
62 ValueRange loRes =
63 rewriter
64 .create<gpu::ShuffleOp>(location: op.getLoc(), args&: lo, args: op.getOffset(),
65 args: op.getWidth(), args: op.getMode())
66 .getResults();
67 ValueRange hiRes =
68 rewriter
69 .create<gpu::ShuffleOp>(location: op.getLoc(), args&: hi, args: op.getOffset(),
70 args: op.getWidth(), args: op.getMode())
71 .getResults();
72
73 // Convert lo back to i64.
74 lo = rewriter.create<arith::ExtUIOp>(location: valueLoc, args&: i64, args: loRes[0]);
75
76 // Convert hi back to i64.
77 hi = rewriter.create<arith::ExtUIOp>(location: valueLoc, args&: i64, args: hiRes[0]);
78 hi = rewriter.create<arith::ShLIOp>(location: valueLoc, args&: hi, args&: c32);
79
80 // Obtain the shuffled bits hi | lo.
81 value = rewriter.create<arith::OrIOp>(location: loc, args&: hi, args&: lo);
82
83 // Convert the value back to float.
84 if (isa<FloatType>(Val: valueType))
85 value = rewriter.create<arith::BitcastOp>(location: valueLoc, args&: valueType, args&: value);
86
87 // Obtain the shuffle validity by combining both validities.
88 auto validity = rewriter.create<arith::AndIOp>(location: loc, args: loRes[1], args: hiRes[1]);
89
90 // Replace the op.
91 rewriter.replaceOp(op, newValues: {value, validity});
92 return success();
93 }
94};
95} // namespace
96
97void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
98 patterns.add<GpuShuffleRewriter>(arg: patterns.getContext());
99}
100

source code of mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp