1 | //===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements in-dialect rewriting of the shuffle op for types i64 and
// f64, rewriting 64-bit shuffles into two 32-bit shuffles. This particular
// implementation, using shifts and truncations, can be obtained from clang by
// emitting IR for shuffle operations with `-O3`.
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
18 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | |
23 | using namespace mlir; |
24 | |
25 | namespace { |
26 | struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> { |
27 | using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern; |
28 | |
29 | void initialize() { |
30 | // Required as the pattern will replace the Op with 2 additional ShuffleOps. |
31 | setHasBoundedRewriteRecursion(); |
32 | } |
33 | LogicalResult matchAndRewrite(gpu::ShuffleOp op, |
34 | PatternRewriter &rewriter) const override { |
35 | auto loc = op.getLoc(); |
36 | auto value = op.getValue(); |
37 | auto valueType = value.getType(); |
38 | auto valueLoc = value.getLoc(); |
39 | auto i32 = rewriter.getI32Type(); |
40 | auto i64 = rewriter.getI64Type(); |
41 | |
42 | // If the type of the value is either i32 or f32, the op is already valid. |
43 | if (valueType.getIntOrFloatBitWidth() == 32) |
44 | return failure(); |
45 | |
46 | Value lo, hi; |
47 | |
48 | // Float types must be converted to i64 to extract the bits. |
49 | if (isa<FloatType>(valueType)) |
50 | value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value); |
51 | |
52 | // Get the low bits by trunc(value). |
53 | lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value); |
54 | |
55 | // Get the high bits by trunc(value >> 32). |
56 | auto c32 = rewriter.create<arith::ConstantOp>( |
57 | valueLoc, rewriter.getIntegerAttr(i64, 32)); |
58 | hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32); |
59 | hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi); |
60 | |
61 | // Shuffle the values. |
62 | ValueRange loRes = |
63 | rewriter |
64 | .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(), |
65 | op.getWidth(), op.getMode()) |
66 | .getResults(); |
67 | ValueRange hiRes = |
68 | rewriter |
69 | .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(), |
70 | op.getWidth(), op.getMode()) |
71 | .getResults(); |
72 | |
73 | // Convert lo back to i64. |
74 | lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]); |
75 | |
76 | // Convert hi back to i64. |
77 | hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]); |
78 | hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32); |
79 | |
80 | // Obtain the shuffled bits hi | lo. |
81 | value = rewriter.create<arith::OrIOp>(loc, hi, lo); |
82 | |
83 | // Convert the value back to float. |
84 | if (isa<FloatType>(valueType)) |
85 | value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value); |
86 | |
87 | // Obtain the shuffle validity by combining both validities. |
88 | auto validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]); |
89 | |
90 | // Replace the op. |
91 | rewriter.replaceOp(op, {value, validity}); |
92 | return success(); |
93 | } |
94 | }; |
95 | } // namespace |
96 | |
97 | void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) { |
98 | patterns.add<GpuShuffleRewriter>(arg: patterns.getContext()); |
99 | } |
100 | |