1//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements gradual lowering of `gpu.subgroup_reduce` ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/GPU/Transforms/Passes.h"
18#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
19#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/TypeUtilities.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/MathExtras.h"
27#include <cassert>
28#include <cstdint>
29
30using namespace mlir;
31
32namespace {
33
/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    // Only vector reductions with at least 2 elements need breaking down;
    // single-element vectors are handled by `ScalarizeSingleElementReduce`.
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    // A single element must fit within one shuffle payload, otherwise this
    // pattern cannot help.
    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    // Number of vector elements that fit in a single shuffle payload.
    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    // Zero-initialized accumulator vector into which each partial reduction
    // result is inserted.
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      // The last chunk may hold fewer than `elementsPerShuffle` elements.
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      // Extract the chunk: a plain element extract for 1-element chunks,
      // otherwise a strided slice.
      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      // Reduce the chunk, preserving the original op's kind, uniformity, and
      // cluster configuration.
      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  // Maximum bitwidth a single `gpu.shuffle` payload may carry; chunks are
  // sized to stay within this limit.
  unsigned maxShuffleBitwidth = 0;
};
119
120/// Example:
121/// ```
122/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
123/// ==>
124/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
125/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
126/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
127/// ```
128struct ScalarizeSingleElementReduce final
129 : OpRewritePattern<gpu::SubgroupReduceOp> {
130 using OpRewritePattern::OpRewritePattern;
131
132 LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
133 PatternRewriter &rewriter) const override {
134 auto vecTy = dyn_cast<VectorType>(Val: op.getType());
135 if (!vecTy || vecTy.getNumElements() != 1)
136 return rewriter.notifyMatchFailure(arg&: op, msg: "not a single-element reduction");
137
138 assert(vecTy.getRank() == 1 && "Unexpected vector type");
139 assert(!vecTy.isScalable() && "Unexpected vector type");
140 Location loc = op.getLoc();
141 Value extracted = rewriter.create<vector::ExtractOp>(location: loc, args: op.getValue(), args: 0);
142 Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
143 location: loc, args&: extracted, args: op.getOp(), args: op.getUniform(), args: op.getClusterSize(),
144 args: op.getClusterStride());
145 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, args&: vecTy, args&: reduce);
146 return success();
147 }
148};
149
/// Validated description of how subgroup lanes are grouped for a (possibly
/// clustered) reduction. Produced by `getAndValidateClusterInfo`.
struct ClusterInfo {
  // Distance between consecutive lanes of a cluster (power of two).
  unsigned clusterStride;
  // Number of lanes reduced together; defaults to the whole subgroup when the
  // op carries no cluster size (power of two, <= subgroupSize).
  unsigned clusterSize;
  // Total number of lanes in the subgroup (power of two).
  unsigned subgroupSize;
};
155
156static FailureOr<ClusterInfo>
157getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
158 assert(llvm::isPowerOf2_32(subgroupSize));
159
160 std::optional<uint32_t> clusterSize = op.getClusterSize();
161 assert(!clusterSize ||
162 llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
163 if (clusterSize && *clusterSize > subgroupSize)
164 return op.emitOpError()
165 << "cluster size " << *clusterSize
166 << " is greater than subgroup size " << subgroupSize;
167 unsigned effectiveClusterSize = clusterSize.value_or(u&: subgroupSize);
168
169 auto clusterStride = op.getClusterStride();
170 assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
171 if (clusterStride >= subgroupSize)
172 return op.emitOpError()
173 << "cluster stride " << clusterStride
174 << " is not less than subgroup size " << subgroupSize;
175
176 return ClusterInfo{.clusterStride: clusterStride, .clusterSize: effectiveClusterSize, .subgroupSize: subgroupSize};
177}
178
/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles. Each step XORs the lane id
  // with offset `i`, combining pairs of partial results; after
  // log2(clusterSize) steps every lane of a cluster holds the cluster's full
  // reduction.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}
211
/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    // Each pattern instance handles either clustered or non-clustered ops,
    // never both.
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    // The scalar must fit in one shuffle payload.
    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    // Narrower scalars are bitcast to a same-width integer and zero-extended
    // to the shuffle width before each shuffle (`packFn`); afterwards they are
    // truncated and bitcast back (`unpackFn`) so the arithmetic reduction runs
    // in the original type.
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
278
/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    // Each pattern instance handles either clustered or non-clustered ops,
    // never both.
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    // The whole vector must fit in a single shuffle payload; wider vectors
    // are expected to be broken down first (see `BreakDownSubgroupReduce`).
    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    // Shuffles carry a single `shuffleBitwidth`-bit integer; pack/unpack
    // bitcast between that integer and the (extended) vector type.
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    // Drop any padding elements introduced by the extension above.
    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
368
/// Emits a reduction of `input` across a cluster of contiguous lanes using
/// AMDGPU DPP (data-parallel primitive) ops, widening the reduced span in
/// power-of-two stages (2, 4, 8, 16, 32, 64 lanes). The cross-row 32- and
/// 64-lane stages are chipset-dependent; returns failure on chipsets where
/// they are not currently supported.
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast last value from each row to next row.
      // Use row mask to avoid polluting rows 1 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
          rewriter.getUnitAttr(), 0xa, allBanks,
          /*bound_ctrl*/ false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
                                                  uint32Max, uint32Max,
                                                  /*fi=*/true,
                                                  /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      // Read the fully reduced value out of lane 31 so every lane observes it.
      Value lane31 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast 31st lane value to rows 2 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
          rewriter.getUnitAttr(), 0xf, allBanks,
          /*bound_ctrl*/ true);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
      // Obtain reduction from last rows, the previous rows are polluted.
      Value lane63 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);

    } else if (chipset.majorVersion <= 12) {
      // Assume reduction across 32 lanes has been done.
      // Perform final reduction manually by summing values in lane 0 and
      // lane 32.
      Value lane31 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      lane31 =
          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
      lane63 =
          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }
  assert(res.getType() == input.getType());
  return res;
}
482
483/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
484/// ops over scalar types. Assumes that the subgroup has
485/// `subgroupSize` lanes. Applicable only to AMD GPUs.
486struct ScalarSubgroupReduceToDPP final
487 : OpRewritePattern<gpu::SubgroupReduceOp> {
488 ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
489 bool matchClustered, amdgpu::Chipset chipset,
490 PatternBenefit benefit)
491 : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
492 matchClustered(matchClustered), chipset(chipset) {}
493
494 LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
495 PatternRewriter &rewriter) const override {
496 if (op.getClusterSize().has_value() != matchClustered) {
497 return rewriter.notifyMatchFailure(
498 arg&: op, msg: llvm::formatv(Fmt: "op is {0}clustered but pattern is configured to "
499 "only match {1}clustered ops",
500 Vals: matchClustered ? "non-" : "",
501 Vals: matchClustered ? "" : "non-"));
502 }
503 auto ci = getAndValidateClusterInfo(op, subgroupSize);
504 if (failed(Result: ci))
505 return failure();
506
507 if (ci->clusterStride != 1)
508 return rewriter.notifyMatchFailure(
509 arg&: op, msg: "Subgroup reductions using DPP are currently only available for "
510 "clusters of contiguous lanes.");
511
512 Type valueTy = op.getType();
513 if (!valueTy.isIntOrFloat())
514 return rewriter.notifyMatchFailure(
515 arg&: op, msg: "Value type is not a compatible scalar.");
516
517 FailureOr<Value> dpp = createSubgroupDPPReduction(
518 rewriter, op, input: op.getValue(), mode: op.getOp(), ci: *ci, chipset);
519 if (failed(Result: dpp))
520 return failure();
521
522 rewriter.replaceOp(op, newValues: dpp.value());
523 return success();
524 }
525
526private:
527 unsigned subgroupSize = 0;
528 bool matchClustered = false;
529 amdgpu::Chipset chipset;
530};
531} // namespace
532
533void mlir::populateGpuBreakDownSubgroupReducePatterns(
534 RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
535 PatternBenefit benefit) {
536 patterns.add<BreakDownSubgroupReduce>(arg: patterns.getContext(),
537 args&: maxShuffleBitwidth, args&: benefit);
538 patterns.add<ScalarizeSingleElementReduce>(arg: patterns.getContext(), args&: benefit);
539}
540
541void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
542 RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
543 PatternBenefit benefit) {
544 patterns.add<ScalarSubgroupReduceToDPP>(arg: patterns.getContext(), args&: subgroupSize,
545 /*matchClustered=*/args: false, args&: chipset,
546 args&: benefit);
547}
548
549void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
550 RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
551 PatternBenefit benefit) {
552 patterns.add<ScalarSubgroupReduceToDPP>(arg: patterns.getContext(), args&: subgroupSize,
553 /*matchClustered=*/args: true, args&: chipset,
554 args&: benefit);
555}
556
557void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
558 RewritePatternSet &patterns, unsigned subgroupSize,
559 unsigned shuffleBitwidth, PatternBenefit benefit) {
560 patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
561 arg: patterns.getContext(), args&: subgroupSize, args&: shuffleBitwidth,
562 /*matchClustered=*/args: false, args&: benefit);
563}
564
565void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
566 RewritePatternSet &patterns, unsigned subgroupSize,
567 unsigned shuffleBitwidth, PatternBenefit benefit) {
568 patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
569 arg: patterns.getContext(), args&: subgroupSize, args&: shuffleBitwidth,
570 /*matchClustered=*/args: true, args&: benefit);
571}
572

source code of mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp