//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
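  /// Callable that emits the accumulation of two partial reduction values and
  /// returns the combined result; produced by getFactory() below.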
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
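    // Linearize the 3-D thread id in x-major order:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ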
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
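    // kSubgroupSize is a power of two, so the bitwise AND below computes
    // invocationIdx % kSubgroupSize.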
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create the factory for the op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
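    // numSubgroups = ceildiv(workgroupSize, kSubgroupSize): adding the
    // subgroup mask (kSubgroupSize - 1) before the division rounds up.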
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory buffer holding the per-subgroup intermediate
  /// results to funcOp's workgroup attributions and returns it.
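  /// The buffer has kSubgroupSize slots (one per subgroup), which assumes a
  /// workgroup of at most kSubgroupSize * kSubgroupSize invocations.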
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op is expected to
  /// return the accumulated value of the same type.
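  ///
  /// For illustration, a reduce op with a custom body roughly of the form
  ///
  ///     %sum = gpu.all_reduce %operand uniform {
  ///     ^bb(%lhs : f32, %rhs : f32):
  ///       %s = arith.addf %lhs, %rhs : f32
  ///       "gpu.yield"(%s) : (f32) -> ()
  ///     } : (f32) -> (f32)
  ///
  /// is lowered by cloning this body once per accumulation step.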
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
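          // After the round with offset 2^k, the first lane has accumulated
          // all active lanes among 0 .. 2^(k+1)-1 (a butterfly reduction).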
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

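// Typical usage is to populate these patterns inside a pass and run a greedy
// rewrite driver over the gpu.func ops; a minimal sketch (pass boilerplate
// assumed):
//
//   RewritePatternSet patterns(&getContext());
//   populateGpuAllReducePatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();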
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}