| 1 | //===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements in-dialect rewriting of the shuffle op for types i64 and |
| 10 | // f64, rewriting 64-bit shuffles into two 32-bit shuffles. This particular |
| 11 | // implementation using shifts and truncations can be obtained from clang by |
| 12 | // emitting IR for shuffle operations with `-O3`. |
| 13 | // |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 18 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 19 | #include "mlir/IR/Builders.h" |
| 20 | #include "mlir/IR/PatternMatch.h" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | |
| 24 | namespace { |
| 25 | struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> { |
| 26 | using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern; |
| 27 | |
| 28 | void initialize() { |
| 29 | // Required as the pattern will replace the Op with 2 additional ShuffleOps. |
| 30 | setHasBoundedRewriteRecursion(); |
| 31 | } |
| 32 | LogicalResult matchAndRewrite(gpu::ShuffleOp op, |
| 33 | PatternRewriter &rewriter) const override { |
| 34 | auto loc = op.getLoc(); |
| 35 | auto value = op.getValue(); |
| 36 | auto valueType = value.getType(); |
| 37 | auto valueLoc = value.getLoc(); |
| 38 | auto i32 = rewriter.getI32Type(); |
| 39 | auto i64 = rewriter.getI64Type(); |
| 40 | |
| 41 | // If the type of the value is either i32 or f32, the op is already valid. |
| 42 | if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64) |
| 43 | return rewriter.notifyMatchFailure( |
| 44 | arg&: op, msg: "only 64-bit int/float types are supported" ); |
| 45 | |
| 46 | Value lo, hi; |
| 47 | |
| 48 | // Float types must be converted to i64 to extract the bits. |
| 49 | if (isa<FloatType>(Val: valueType)) |
| 50 | value = rewriter.create<arith::BitcastOp>(location: valueLoc, args&: i64, args&: value); |
| 51 | |
| 52 | // Get the low bits by trunc(value). |
| 53 | lo = rewriter.create<arith::TruncIOp>(location: valueLoc, args&: i32, args&: value); |
| 54 | |
| 55 | // Get the high bits by trunc(value >> 32). |
| 56 | auto c32 = rewriter.create<arith::ConstantOp>( |
| 57 | location: valueLoc, args: rewriter.getIntegerAttr(type: i64, value: 32)); |
| 58 | hi = rewriter.create<arith::ShRUIOp>(location: valueLoc, args&: value, args&: c32); |
| 59 | hi = rewriter.create<arith::TruncIOp>(location: valueLoc, args&: i32, args&: hi); |
| 60 | |
| 61 | // Shuffle the values. |
| 62 | ValueRange loRes = |
| 63 | rewriter |
| 64 | .create<gpu::ShuffleOp>(location: op.getLoc(), args&: lo, args: op.getOffset(), |
| 65 | args: op.getWidth(), args: op.getMode()) |
| 66 | .getResults(); |
| 67 | ValueRange hiRes = |
| 68 | rewriter |
| 69 | .create<gpu::ShuffleOp>(location: op.getLoc(), args&: hi, args: op.getOffset(), |
| 70 | args: op.getWidth(), args: op.getMode()) |
| 71 | .getResults(); |
| 72 | |
| 73 | // Convert lo back to i64. |
| 74 | lo = rewriter.create<arith::ExtUIOp>(location: valueLoc, args&: i64, args: loRes[0]); |
| 75 | |
| 76 | // Convert hi back to i64. |
| 77 | hi = rewriter.create<arith::ExtUIOp>(location: valueLoc, args&: i64, args: hiRes[0]); |
| 78 | hi = rewriter.create<arith::ShLIOp>(location: valueLoc, args&: hi, args&: c32); |
| 79 | |
| 80 | // Obtain the shuffled bits hi | lo. |
| 81 | value = rewriter.create<arith::OrIOp>(location: loc, args&: hi, args&: lo); |
| 82 | |
| 83 | // Convert the value back to float. |
| 84 | if (isa<FloatType>(Val: valueType)) |
| 85 | value = rewriter.create<arith::BitcastOp>(location: valueLoc, args&: valueType, args&: value); |
| 86 | |
| 87 | // Obtain the shuffle validity by combining both validities. |
| 88 | auto validity = rewriter.create<arith::AndIOp>(location: loc, args: loRes[1], args: hiRes[1]); |
| 89 | |
| 90 | // Replace the op. |
| 91 | rewriter.replaceOp(op, newValues: {value, validity}); |
| 92 | return success(); |
| 93 | } |
| 94 | }; |
| 95 | } // namespace |
| 96 | |
| 97 | void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) { |
| 98 | patterns.add<GpuShuffleRewriter>(arg: patterns.getContext()); |
| 99 | } |
| 100 | |