//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
///  ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
///  ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

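/// Describes the cluster shape of a subgroup reduction: lanes are grouped into
/// clusters of `clusterSize` lanes, `clusterStride` lanes apart, within a
/// subgroup that is `subgroupSize` lanes wide.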
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

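/// Returns the validated cluster parameters of `op` relative to
/// `subgroupSize`, or failure if the cluster size exceeds the subgroup size or
/// the cluster stride is not smaller than the subgroup size.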
static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
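///
/// For illustration only (a sketch, not emitted verbatim): with a scalar `f32`
/// input, `add` mode, identity pack/unpack functions, `clusterStride` 1,
/// `clusterSize` 4, and `subgroupSize` 32, the emitted IR looks roughly like:
/// ```
/// %offset1 = arith.constant 1 : i32
/// %width = arith.constant 32 : i32
/// %s1, %valid1 = gpu.shuffle xor %x, %offset1, %width : f32
/// %r1 = arith.addf %x, %s1 : f32
/// %offset2 = arith.constant 2 : i32
/// %s2, %valid2 = gpu.shuffle xor %r1, %offset2, %width : f32
/// %r2 = arith.addf %r1, %s2 : f32
/// ```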
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
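///
/// For a scalar narrower than the shuffle width (e.g. `f16` with 32-bit
/// shuffles), each shuffle operand is packed by bitcasting to the equivalent
/// integer type and zero-extending to the shuffle width, and unpacked by the
/// inverse sequence. A sketch of a single shuffle step, assuming `add` mode:
/// ```
/// %p = arith.bitcast %val : f16 to i16
/// %pi = arith.extui %p : i16 to i32
/// %s, %valid = gpu.shuffle xor %pi, %offset, %width : i32
/// %t = arith.trunci %s : i32 to i16
/// %u = arith.bitcast %t : i16 to f16
/// %r = arith.addf %val, %u : f16
/// ```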
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
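///
/// For illustration only (a sketch): a `vector<2xf16>` reduced over 32-bit
/// shuffles is packed with `vector.bitcast` plus `vector.extract`, and
/// unpacked with `vector.broadcast` plus `vector.bitcast`. A single shuffle
/// step, assuming `add` mode, looks roughly like:
/// ```
/// %pv = vector.bitcast %val : vector<2xf16> to vector<1xi32>
/// %p = vector.extract %pv[0] : i32 from vector<1xi32>
/// %s, %valid = gpu.shuffle xor %p, %offset, %width : i32
/// %sv = vector.broadcast %s : i32 to vector<1xi32>
/// %u = vector.bitcast %sv : vector<1xi32> to vector<2xf16>
/// %r = arith.addf %val, %u : vector<2xf16>
/// ```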
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

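/// Emits a subgroup reduction of `input` over a cluster of `ci.clusterSize`
/// contiguous lanes using AMDGPU DPP (data-parallel primitives) ops: quad
/// permutes and row mirrors within a row, and, for cluster sizes of 32 and 64,
/// chipset-dependent row broadcasts or cross-row lane permutes, with
/// `rocdl.readlane` to extract the final value where needed. Returns failure
/// on chipsets without a supported lowering.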
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast last value from each row to next row.
      // Use row mask to avoid polluting rows 1 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
          rewriter.getUnitAttr(), 0xa, allBanks,
          /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
                                                  uint32Max, uint32Max,
                                                  /*fi=*/true,
                                                  /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      Value lane31 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast 31st lane value to rows 2 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
          rewriter.getUnitAttr(), 0xf, allBanks,
          /*bound_ctrl=*/true);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
      // Obtain reduction from the last rows; the previous rows are polluted.
      Value lane63 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);

    } else if (chipset.majorVersion <= 12) {
      // Assume the reduction across each 32-lane half has been done. Perform
      // the final reduction by combining the partial results held in lane 31
      // and lane 63.
      Value lane31 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      lane31 =
          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane31);
      lane63 =
          rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }
  assert(res.getType() == input.getType());
  return res;
}

/// Lowers scalar `gpu.subgroup_reduce` ops into a sequence of `amdgpu.dpp`
/// ops. Assumes that the subgroup has `subgroupSize` lanes. Applicable only to
/// AMD GPUs.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");

    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");

    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
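
// Usage sketch (illustrative only, not part of this file's API): a pass would
// typically combine the patterns above so that wide reductions are first
// broken down and then lowered to shuffles. The subgroup size, shuffle
// bitwidth, and pattern benefits below are assumptions chosen for the example.
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32,
//                                              /*benefit=*/2);
//   populateGpuLowerSubgroupReduceToShufflePatterns(patterns,
//                                                   /*subgroupSize=*/32,
//                                                   /*shuffleBitwidth=*/32,
//                                                   /*benefit=*/1);
//   (void)applyPatternsGreedily(op, std::move(patterns));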