1//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements in-dialect rewriting of the shuffle op for types i64 and
// f64, rewriting each 64-bit shuffle into two 32-bit shuffles. The particular
// shift-and-truncate sequence used here mirrors the IR that clang emits for
// 64-bit shuffle operations when compiling with `-O3`.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/GPU/Transforms/Passes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/Pass/Pass.h"
22
23using namespace mlir;
24
25namespace {
26struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
27 using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
28
29 void initialize() {
30 // Required as the pattern will replace the Op with 2 additional ShuffleOps.
31 setHasBoundedRewriteRecursion();
32 }
33 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
34 PatternRewriter &rewriter) const override {
35 auto loc = op.getLoc();
36 auto value = op.getValue();
37 auto valueType = value.getType();
38 auto valueLoc = value.getLoc();
39 auto i32 = rewriter.getI32Type();
40 auto i64 = rewriter.getI64Type();
41
42 // If the type of the value is either i32 or f32, the op is already valid.
43 if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
44 return rewriter.notifyMatchFailure(
45 op, "only 64-bit int/float types are supported");
46
47 Value lo, hi;
48
49 // Float types must be converted to i64 to extract the bits.
50 if (isa<FloatType>(valueType))
51 value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);
52
53 // Get the low bits by trunc(value).
54 lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);
55
56 // Get the high bits by trunc(value >> 32).
57 auto c32 = rewriter.create<arith::ConstantOp>(
58 valueLoc, rewriter.getIntegerAttr(i64, 32));
59 hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
60 hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);
61
62 // Shuffle the values.
63 ValueRange loRes =
64 rewriter
65 .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
66 op.getWidth(), op.getMode())
67 .getResults();
68 ValueRange hiRes =
69 rewriter
70 .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(),
71 op.getWidth(), op.getMode())
72 .getResults();
73
74 // Convert lo back to i64.
75 lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]);
76
77 // Convert hi back to i64.
78 hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]);
79 hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);
80
81 // Obtain the shuffled bits hi | lo.
82 value = rewriter.create<arith::OrIOp>(loc, hi, lo);
83
84 // Convert the value back to float.
85 if (isa<FloatType>(valueType))
86 value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);
87
88 // Obtain the shuffle validity by combining both validities.
89 auto validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]);
90
91 // Replace the op.
92 rewriter.replaceOp(op, {value, validity});
93 return success();
94 }
95};
96} // namespace
97
98void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
99 patterns.add<GpuShuffleRewriter>(arg: patterns.getContext());
100}
101

// Source: mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp