//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
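///
/// An illustrative sketch (not emitted verbatim; SSA names are invented): for
/// an `f32` add reduction with `subgroupSize` equal to 4, the generated
/// sequence is equivalent to the following, where %c1, %c2, and %c4 are i32
/// constants:
/// ```
/// %s0, %p0 = gpu.shuffle xor %x, %c1, %c4 : f32
/// %a0 = arith.addf %x, %s0 : f32
/// %s1, %p1 = gpu.shuffle xor %a0, %c2, %c4 : f32
/// %a1 = arith.addf %a0, %s1 : f32
/// ```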
static Value createSubgroupShuffleReduction(
    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
    unsigned subgroupSize, function_ref<Value(Value)> packFn,
    function_ref<Value(Value)> unpackFn) {
  assert(llvm::isPowerOf2_32(subgroupSize));
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
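  // Each XOR mask pairs lanes whose IDs differ in exactly one bit; after
  // log2(subgroupSize) exchange-and-reduce steps, every lane holds the
  // reduction over the entire subgroup.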
  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
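///
/// An illustrative sketch of the pack/unpack round-trip performed around each
/// shuffle, assuming an `f16` input and `shuffleBitwidth` equal to 32 (SSA
/// names are invented):
/// ```
/// %i = arith.bitcast %x : f16 to i16
/// %p = arith.extui %i : i16 to i32   // packFn; %p feeds gpu.shuffle
/// %s, %valid = gpu.shuffle xor %p, %mask, %width : i32
/// %t = arith.trunci %s : i32 to i16  // unpackFn on the shuffle result
/// %u = arith.bitcast %t : i16 to f16
/// ```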
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(),
                                 subgroupSize, identityFn, identityFn));
      return success();
    }

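    // Pack by bitcasting to a same-width integer and zero-extending to the
    // shuffle bitwidth; unpack by truncating and bitcasting back.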
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(op, createSubgroupShuffleReduction(
                               rewriter, loc, op.getValue(), op.getOp(),
                               subgroupSize, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
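///
/// An illustrative sketch of the pack/unpack round-trip performed around each
/// shuffle, assuming a `vector<2xf16>` input and `shuffleBitwidth` equal to 32
/// (SSA names are invented):
/// ```
/// %v = vector.bitcast %x : vector<2xf16> to vector<1xi32>
/// %p = vector.extract %v[0] : i32 from vector<1xi32>  // packFn
/// %s, %valid = gpu.shuffle xor %p, %mask, %width : i32
/// %b = vector.broadcast %s : i32 to vector<1xi32>     // unpackFn
/// %u = vector.bitcast %b : vector<1xi32> to vector<2xf16>
/// ```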
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

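    // Pack by bitcasting the (possibly extended) vector to a single-element
    // integer vector and extracting the scalar; unpack reverses these steps.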
    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res =
        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
                                       subgroupSize, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
}