//===- DistributionUtils.cpp - Distribution tools for GPUOps --------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements distribution utility methods.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"

#include <numeric>

using namespace mlir;
using namespace mlir::gpu;
22 | |
23 | WarpExecuteOnLane0Op |
24 | WarpDistributionPattern::moveRegionToNewWarpOpAndReplaceReturns( |
25 | RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, |
26 | ValueRange newYieldedValues, TypeRange newReturnTypes) const { |
27 | // Create a new op before the existing one, with the extra operands. |
28 | OpBuilder::InsertionGuard g(rewriter); |
29 | rewriter.setInsertionPoint(warpOp); |
30 | auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>( |
31 | warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(), |
32 | warpOp.getArgs(), warpOp.getBody()->getArgumentTypes()); |
33 | |
34 | Region &opBody = warpOp.getBodyRegion(); |
35 | Region &newOpBody = newWarpOp.getBodyRegion(); |
36 | Block &newOpFirstBlock = newOpBody.front(); |
37 | rewriter.inlineRegionBefore(region&: opBody, parent&: newOpBody, before: newOpBody.begin()); |
38 | rewriter.eraseBlock(block: &newOpFirstBlock); |
39 | assert(newWarpOp.getWarpRegion().hasOneBlock() && |
40 | "expected WarpOp with single block" ); |
41 | |
42 | auto yield = |
43 | cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator()); |
44 | |
45 | rewriter.modifyOpInPlace( |
46 | yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); }); |
47 | return newWarpOp; |
48 | } |
49 | |
50 | WarpExecuteOnLane0Op |
51 | WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns( |
52 | RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, |
53 | ValueRange newYieldedValues, TypeRange newReturnTypes, |
54 | SmallVector<size_t> &indices) const { |
55 | SmallVector<Type> types(warpOp.getResultTypes().begin(), |
56 | warpOp.getResultTypes().end()); |
57 | auto yield = cast<gpu::YieldOp>( |
58 | warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
59 | llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(), |
60 | yield.getOperands().end()); |
61 | for (auto [value, type] : llvm::zip_equal(t&: newYieldedValues, u&: newReturnTypes)) { |
62 | if (yieldValues.insert(X: value)) { |
63 | types.push_back(Elt: type); |
64 | indices.push_back(Elt: yieldValues.size() - 1); |
65 | } else { |
66 | // If the value already exit the region don't create a new output. |
67 | for (auto [idx, yieldOperand] : |
68 | llvm::enumerate(yieldValues.getArrayRef())) { |
69 | if (yieldOperand == value) { |
70 | indices.push_back(idx); |
71 | break; |
72 | } |
73 | } |
74 | } |
75 | } |
76 | yieldValues.insert_range(R&: newYieldedValues); |
77 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( |
78 | rewriter, warpOp, yieldValues.getArrayRef(), types); |
79 | rewriter.replaceOp(warpOp, |
80 | newWarpOp.getResults().take_front(warpOp.getNumResults())); |
81 | return newWarpOp; |
82 | } |
83 | |
84 | OpOperand *WarpDistributionPattern::getWarpResult( |
85 | WarpExecuteOnLane0Op warpOp, |
86 | llvm::function_ref<bool(Operation *)> fn) const { |
87 | auto yield = cast<gpu::YieldOp>( |
88 | warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
89 | for (OpOperand &yieldOperand : yield->getOpOperands()) { |
90 | Value yieldValues = yieldOperand.get(); |
91 | Operation *definedOp = yieldValues.getDefiningOp(); |
92 | if (definedOp && fn(definedOp)) { |
93 | if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) |
94 | return &yieldOperand; |
95 | } |
96 | } |
97 | return nullptr; |
98 | } |
99 | |
100 | bool WarpDistributionPattern::delinearizeLaneId( |
101 | OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape, |
102 | ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId, |
103 | SmallVectorImpl<Value> &delinearizedIds) const { |
104 | // If the original shape and the distributed shape is the same, we don't |
105 | // distribute at all--every thread is handling the whole. For such case, we |
106 | // should not rely on lane IDs later. So just return an empty lane ID vector. |
107 | if (originalShape == distributedShape) { |
108 | delinearizedIds.clear(); |
109 | return true; |
110 | } |
111 | |
112 | SmallVector<int64_t> sizes; |
113 | for (auto [large, small] : llvm::zip_equal(t&: originalShape, u&: distributedShape)) { |
114 | if (large % small != 0) |
115 | return false; |
116 | sizes.push_back(Elt: large / small); |
117 | } |
118 | if (std::accumulate(first: sizes.begin(), last: sizes.end(), init: 1, |
119 | binary_op: std::multiplies<int64_t>()) != warpSize) |
120 | return false; |
121 | |
122 | AffineExpr s0, s1; |
123 | bindSymbols(ctx: builder.getContext(), exprs&: s0, exprs&: s1); |
124 | |
125 | int64_t usedThreads = 1; |
126 | |
127 | Value zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0); |
128 | delinearizedIds.assign(NumElts: sizes.size(), Elt: zero); |
129 | |
130 | for (int i = sizes.size() - 1; i >= 0; --i) { |
131 | usedThreads *= sizes[i]; |
132 | if (usedThreads == warpSize) { |
133 | // We've used up all available threads. Don't need to perform modulo |
134 | // anymore. And we can stop the calculation for further dimensions. |
135 | delinearizedIds[i] = laneId; |
136 | break; |
137 | } |
138 | delinearizedIds[i] = |
139 | affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId}); |
140 | laneId = affine::makeComposedAffineApply( |
141 | builder, loc, s0.floorDiv(v: usedThreads), {laneId}); |
142 | } |
143 | return true; |
144 | } |
145 | |