1//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements in-dialect rewriting of the shuffle op for types i64 and
// f64, rewriting 64-bit shuffles into two 32-bit shuffles. This particular
// implementation, which uses shifts and truncations, can be obtained from clang
// by emitting IR for shuffle operations with `-O3`.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/GPU/Transforms/Passes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/Pass/Pass.h"
22
23using namespace mlir;
24
25namespace {
26struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
27 using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
28
29 void initialize() {
30 // Required as the pattern will replace the Op with 2 additional ShuffleOps.
31 setHasBoundedRewriteRecursion();
32 }
33 LogicalResult matchAndRewrite(gpu::ShuffleOp op,
34 PatternRewriter &rewriter) const override {
35 auto loc = op.getLoc();
36 auto value = op.getValue();
37 auto valueType = value.getType();
38 auto valueLoc = value.getLoc();
39 auto i32 = rewriter.getI32Type();
40 auto i64 = rewriter.getI64Type();
41
42 // If the type of the value is either i32 or f32, the op is already valid.
43 if (valueType.getIntOrFloatBitWidth() == 32)
44 return failure();
45
46 Value lo, hi;
47
48 // Float types must be converted to i64 to extract the bits.
49 if (isa<FloatType>(valueType))
50 value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);
51
52 // Get the low bits by trunc(value).
53 lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);
54
55 // Get the high bits by trunc(value >> 32).
56 auto c32 = rewriter.create<arith::ConstantOp>(
57 valueLoc, rewriter.getIntegerAttr(i64, 32));
58 hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
59 hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);
60
61 // Shuffle the values.
62 ValueRange loRes =
63 rewriter
64 .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(),
65 op.getWidth(), op.getMode())
66 .getResults();
67 ValueRange hiRes =
68 rewriter
69 .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(),
70 op.getWidth(), op.getMode())
71 .getResults();
72
73 // Convert lo back to i64.
74 lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64, loRes[0]);
75
76 // Convert hi back to i64.
77 hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64, hiRes[0]);
78 hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);
79
80 // Obtain the shuffled bits hi | lo.
81 value = rewriter.create<arith::OrIOp>(loc, hi, lo);
82
83 // Convert the value back to float.
84 if (isa<FloatType>(valueType))
85 value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);
86
87 // Obtain the shuffle validity by combining both validities.
88 auto validity = rewriter.create<arith::AndIOp>(loc, loRes[1], hiRes[1]);
89
90 // Replace the op.
91 rewriter.replaceOp(op, {value, validity});
92 return success();
93 }
94};
95} // namespace
96
97void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
98 patterns.add<GpuShuffleRewriter>(arg: patterns.getContext());
99}
100

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp