//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

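/// Returns the validated cluster configuration of `op` for a subgroup that is
/// `subgroupSize` lanes wide, or failure if the requested cluster size or
/// stride does not fit within the subgroup. Ops without an explicit cluster
/// size are treated as reducing over the whole subgroup.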
static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
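///
/// For example (an illustrative sketch with made-up SSA names), an `add`
/// reduction of an `f32` value over a cluster of 4 contiguous lanes in a
/// 32-lane subgroup, with identity pack/unpack functions, expands to two
/// butterfly steps:
/// ```
/// %offset1 = arith.constant 1 : i32
/// %width = arith.constant 32 : i32
/// %shfl1, %valid1 = gpu.shuffle xor %x, %offset1, %width : f32
/// %red1 = arith.addf %x, %shfl1 : f32
/// %offset2 = arith.constant 2 : i32
/// %shfl2, %valid2 = gpu.shuffle xor %red1, %offset2, %width : f32
/// %red2 = arith.addf %red1, %shfl2 : f32
/// ```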
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
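///
/// Scalars as wide as the native shuffle type are shuffled directly. Narrower
/// scalars are bitcast to an integer of the same width and zero-extended
/// before each shuffle, then truncated and bitcast back for the arithmetic
/// step. For example (an illustrative sketch with made-up SSA names), an `f16`
/// value is packed and unpacked for 32-bit shuffles roughly as:
/// ```
/// %p0 = arith.bitcast %lane : f16 to i16
/// %p1 = arith.extui %p0 : i16 to i32
/// // ... %shuffled = gpu.shuffle xor of %p1 ...
/// %u0 = arith.trunci %shuffled : i32 to i16
/// %u1 = arith.bitcast %u0 : i16 to f16
/// ```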
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
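///
/// The reduced vector must fit into a single shuffle: if it is narrower than
/// the native shuffle type, it is first zero-padded to that width. Each
/// shuffle packs the vector into a single integer and unpacks it afterwards.
/// For example (an illustrative sketch with made-up SSA names), a
/// `vector<2xf16>` value is packed and unpacked for 32-bit shuffles roughly
/// as:
/// ```
/// %p0 = vector.bitcast %v : vector<2xf16> to vector<1xi32>
/// %p1 = vector.extract %p0[0] : i32 from vector<1xi32>
/// // ... %shuffled = gpu.shuffle xor of %p1 ...
/// %u0 = vector.broadcast %shuffled : i32 to vector<1xi32>
/// %u1 = vector.bitcast %u0 : vector<1xi32> to vector<2xf16>
/// ```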
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

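/// Emits a reduction of `input` across a cluster of contiguous lanes using AMD
/// DPP (data parallel primitives) permutes, combined with
/// `rocdl.permlanex16`/`rocdl.readlane` for the cross-row steps that DPP
/// cannot express on the given chipset. Returns failure when the chipset does
/// not currently support the required cross-row operations.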
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast last value from each row to next row.
      // Use row mask to avoid polluting rows 1 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
          rewriter.getUnitAttr(), 0xa, allBanks,
          /*bound_ctrl*/ false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
                                                  uint32Max, uint32Max,
                                                  /*fi=*/true,
                                                  /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      Value lane31 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast 31st lane value to rows 2 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
          rewriter.getUnitAttr(), 0xf, allBanks,
          /*bound_ctrl*/ true);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
      // Obtain the reduction from the last row; the previous rows are
      // polluted.
      Value lane63 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
    } else if (chipset.majorVersion <= 12) {
      // Assume the reductions across each group of 32 lanes have been done.
      // Perform the final reduction manually by combining the values in
      // lane 31 and lane 63.
      Value lane31 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      lane31 =
          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
      lane63 =
          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }
  assert(res.getType() == input.getType());
  return res;
}

/// Lowers scalar `gpu.subgroup_reduce` ops to a sequence of `amdgpu.dpp` ops.
/// Assumes that the subgroup has `subgroupSize` lanes. Applicable only to AMD
/// GPUs.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");

    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");

    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
} // namespace

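// Example of wiring these patterns into a rewrite driver (an illustrative
// sketch; the anchor op, parameter values, and greedy driver entry point are
// assumptions and may differ across MLIR versions):
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32,
//                                              /*benefit=*/2);
//   populateGpuLowerSubgroupReduceToShufflePatterns(
//       patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, /*benefit=*/1);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
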
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}