//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"

#include <array>
#include <cassert>
#include <functional>
#include <utility>
23
24using namespace mlir;
25
26namespace {
27
28struct GpuAllReduceRewriter {
29 using AccumulatorFactory = std::function<Value(Value, Value)>;
30
31 GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
32 PatternRewriter &rewriter)
33 : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
34 loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
35 indexType(IndexType::get(context: reduceOp.getContext())),
36 int32Type(IntegerType::get(context: reduceOp.getContext(), /*width=*/32)) {}
37
38 /// Creates an all_reduce across the workgroup.
39 ///
40 /// First reduce the elements within a subgroup. The first invocation of each
41 /// subgroup writes the intermediate result to workgroup memory. After
42 /// synchronizing the workgroup, the first subgroup reduces the values from
43 /// workgroup memory. The result is broadcasted to all invocations through
44 /// workgroup memory.
45 ///
46 /// %subgroup_reduce = `createSubgroupReduce(%operand)`
47 /// cf.cond_br %is_first_lane, ^then1, ^continue1
48 /// ^then1:
49 /// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
50 /// cf.br ^continue1
51 /// ^continue1:
52 /// gpu.barrier
53 /// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
54 /// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
55 /// ^then2:
56 /// %partial_reduce = load %workgroup_buffer[%invocation_idx]
57 /// %all_reduce = `createSubgroupReduce(%partial_reduce)`
58 /// store %all_reduce, %workgroup_buffer[%zero]
59 /// llvm.br ^continue2
60 /// ^continue2:
61 /// gpu.barrier
62 /// %result = load %workgroup_buffer[%zero]
63 /// return %result
64 ///
65 void rewrite() {
66 rewriter.setInsertionPoint(reduceOp);
67
68 // Compute linear invocation index and workgroup size.
69 Value dimX = getDimOp<gpu::BlockDimOp>(dimension: gpu::Dimension::x);
70 Value dimY = getDimOp<gpu::BlockDimOp>(dimension: gpu::Dimension::y);
71 Value dimZ = getDimOp<gpu::BlockDimOp>(dimension: gpu::Dimension::z);
72 Value tidX = getDimOp<gpu::ThreadIdOp>(dimension: gpu::Dimension::x);
73 Value tidY = getDimOp<gpu::ThreadIdOp>(dimension: gpu::Dimension::y);
74 Value tidZ = getDimOp<gpu::ThreadIdOp>(dimension: gpu::Dimension::z);
75 Value tmp1 = create<arith::MulIOp>(args: int32Type, args: tidZ, args: dimY);
76 Value tmp2 = create<arith::AddIOp>(args: int32Type, args: tmp1, args: tidY);
77 Value tmp3 = create<arith::MulIOp>(args: int32Type, args: tmp2, args: dimX);
78 Value tmp4 = create<arith::MulIOp>(args: int32Type, args: dimX, args: dimY);
79 Value invocationIdx = create<arith::AddIOp>(args: int32Type, args: tmp3, args: tidX);
80 Value workgroupSize = create<arith::MulIOp>(args: int32Type, args: tmp4, args: dimZ);
81
82 // Compute lane id (invocation id withing the subgroup).
83 Value subgroupMask =
84 create<arith::ConstantIntOp>(args: int32Type, args: kSubgroupSize - 1);
85 Value laneId = create<arith::AndIOp>(args: invocationIdx, args: subgroupMask);
86 Value isFirstLane =
87 create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args: laneId,
88 args: create<arith::ConstantIntOp>(args: int32Type, args: 0));
89
90 Value numThreadsWithSmallerSubgroupId =
91 create<arith::SubIOp>(args: invocationIdx, args: laneId);
92 // The number of active invocations starting from the current subgroup.
93 // The consumers do not require the value to be clamped to the size of the
94 // subgroup.
95 Value activeWidth =
96 create<arith::SubIOp>(args: workgroupSize, args: numThreadsWithSmallerSubgroupId);
97
98 // Create factory for op which accumulates to values.
99 AccumulatorFactory accumFactory = getFactory();
100 assert(accumFactory && "failed to create accumulator factory");
101
102 // Reduce elements within each subgroup to produce the intermediate results.
103 Value subgroupReduce = createSubgroupReduce(
104 activeWidth, laneId, operand: reduceOp.getValue(), accumFactory);
105
106 // Add workgroup buffer to parent function for intermediate result.
107 Value buffer = createWorkgroupBuffer();
108
109 // Write the intermediate results to workgroup memory, using the first lane
110 // of each subgroup.
111 createPredicatedBlock(condition: isFirstLane, predicatedOpsFactory: [&] {
112 Value subgroupId = getDivideBySubgroupSize(value: invocationIdx);
113 Value index = create<arith::IndexCastOp>(args: indexType, args: subgroupId);
114 create<memref::StoreOp>(args: subgroupReduce, args: buffer, args: index);
115 });
116 create<gpu::BarrierOp>();
117
118 // Compute number of active subgroups.
119 Value biasedBlockSize =
120 create<arith::AddIOp>(args: int32Type, args: workgroupSize, args: subgroupMask);
121 Value numSubgroups = getDivideBySubgroupSize(value: biasedBlockSize);
122 Value isValidSubgroup = create<arith::CmpIOp>(args: arith::CmpIPredicate::slt,
123 args: invocationIdx, args: numSubgroups);
124
125 // Use the first numSubgroups invocations to reduce the intermediate results
126 // from workgroup memory. The final result is written to workgroup memory
127 // again.
128 Value zero = create<arith::ConstantIndexOp>(args: 0);
129 createPredicatedBlock(condition: isValidSubgroup, predicatedOpsFactory: [&] {
130 Value index = create<arith::IndexCastOp>(args: indexType, args: invocationIdx);
131 Value value = create<memref::LoadOp>(args: valueType, args: buffer, args: index);
132 Value result =
133 createSubgroupReduce(activeWidth: numSubgroups, laneId, operand: value, accumFactory);
134 create<memref::StoreOp>(args: result, args: buffer, args: zero);
135 });
136
137 // Synchronize workgroup and load result from workgroup memory.
138 create<gpu::BarrierOp>();
139 Value result = create<memref::LoadOp>(args: valueType, args: buffer, args: zero);
140
141 rewriter.replaceOp(op: reduceOp, newValues: result);
142 }
143
144private:
145 // Shortcut to create an op from rewriter using loc as the first argument.
146 template <typename T, typename... Args>
147 T create(Args... args) {
148 return rewriter.create<T>(loc, std::forward<Args>(args)...);
149 }
150
151 // Creates dimension op of type T, with the result casted to int32.
152 template <typename T>
153 Value getDimOp(gpu::Dimension dimension) {
154 Value dim = create<T>(indexType, dimension);
155 return create<arith::IndexCastOp>(args: int32Type, args: dim);
156 }
157
158 /// Adds type to funcOp's workgroup attributions.
159 Value createWorkgroupBuffer() {
160 // TODO: Pick a proper location for the attribution.
161 auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
162 context: funcOp->getContext(), value: gpu::GPUDialect::getWorkgroupAddressSpace());
163 auto bufferType = MemRefType::get(shape: {kSubgroupSize}, elementType: valueType, map: AffineMap{},
164 memorySpace: workgroupMemoryAddressSpace);
165 return funcOp.addWorkgroupAttribution(type: bufferType, loc: rewriter.getUnknownLoc());
166 }
167
168 /// Returns an accumulator factory using either the op attribute or the body
169 /// region.
170 AccumulatorFactory getFactory() {
171 auto &body = reduceOp.getBody();
172 if (!body.empty())
173 return getFactory(body);
174 auto opAttr = reduceOp.getOp();
175 if (opAttr)
176 return getFactory(opName: *opAttr);
177 return AccumulatorFactory();
178 }
179
180 /// Returns an accumulator factory that clones the body. The body's entry
181 /// block is expected to have 2 arguments. The gpu.yield return the
182 /// accumulated value of the same type.
183 AccumulatorFactory getFactory(Region &body) {
184 return [&body, this](Value lhs, Value rhs) -> Value {
185 Block *block = rewriter.getInsertionBlock();
186 Block *split = rewriter.splitBlock(block, before: rewriter.getInsertionPoint());
187
188 // Insert accumulator body between split block.
189 IRMapping mapping;
190 mapping.map(from: body.getArgument(i: 0), to: lhs);
191 mapping.map(from: body.getArgument(i: 1), to: rhs);
192 rewriter.cloneRegionBefore(region&: body, parent&: *split->getParent(),
193 before: split->getIterator(), mapping);
194
195 // Add branch before inserted body, into body.
196 block = block->getNextNode();
197 create<cf::BranchOp>(args: block, args: ValueRange());
198
199 // Replace all gpu.yield ops with branch out of body.
200 for (; block != split; block = block->getNextNode()) {
201 Operation *terminator = block->getTerminator();
202 if (!isa<gpu::YieldOp>(Val: terminator))
203 continue;
204 rewriter.setInsertionPointToEnd(block);
205 rewriter.replaceOpWithNewOp<cf::BranchOp>(
206 op: terminator, args&: split, args: ValueRange(terminator->getOperand(idx: 0)));
207 }
208
209 // Return accumulator result.
210 rewriter.setInsertionPointToStart(split);
211 return split->addArgument(type: lhs.getType(), loc: lhs.getLoc());
212 };
213 }
214
215 /// Returns an accumulator factory that creates an op specified by opName.
216 AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
217 return [opName, this](Value lhs, Value rhs) {
218 return vector::makeArithReduction(b&: rewriter, loc,
219 kind: convertReductionKind(mode: opName), v1: lhs, acc: rhs);
220 };
221 }
222
223 /// Creates an if-block skeleton and calls the two factories to generate the
224 /// ops in the `then` and `else` block..
225 ///
226 /// llvm.cond_br %condition, ^then, ^continue
227 /// ^then:
228 /// %then_operands = `thenOpsFactory()`
229 /// llvm.br ^continue(%then_operands)
230 /// ^else:
231 /// %else_operands = `elseOpsFactory()`
232 /// llvm.br ^continue(%else_operands)
233 /// ^continue(%block_operands):
234 ///
235 template <typename ThenOpsFactory, typename ElseOpsFactory>
236 void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
237 ElseOpsFactory &&elseOpsFactory) {
238 Block *currentBlock = rewriter.getInsertionBlock();
239 auto currentPoint = rewriter.getInsertionPoint();
240
241 Block *thenBlock = rewriter.splitBlock(block: currentBlock, before: currentPoint);
242 Block *elseBlock = rewriter.splitBlock(block: thenBlock, before: thenBlock->begin());
243 Block *continueBlock = rewriter.splitBlock(block: elseBlock, before: elseBlock->begin());
244
245 rewriter.setInsertionPointToEnd(currentBlock);
246 create<cf::CondBranchOp>(args: condition, args: thenBlock,
247 /*trueOperands=*/args: ArrayRef<Value>(), args: elseBlock,
248 /*falseOperands=*/args: ArrayRef<Value>());
249
250 rewriter.setInsertionPointToStart(thenBlock);
251 auto thenOperands = thenOpsFactory();
252 create<cf::BranchOp>(continueBlock, thenOperands);
253
254 rewriter.setInsertionPointToStart(elseBlock);
255 auto elseOperands = elseOpsFactory();
256 create<cf::BranchOp>(continueBlock, elseOperands);
257
258 assert(thenOperands.size() == elseOperands.size());
259 rewriter.setInsertionPointToStart(continueBlock);
260 for (auto operand : thenOperands)
261 continueBlock->addArgument(type: operand.getType(), loc: operand.getLoc());
262 }
263
264 /// Shortcut for createIf with empty else block and no block operands.
265 template <typename Factory>
266 void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
267 static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
268 "predicatedOpsFactory should not return any value");
269 createIf(
270 condition,
271 [&] {
272 predicatedOpsFactory();
273 return ArrayRef<Value>();
274 },
275 [&] { return ArrayRef<Value>(); });
276 }
277
278 /// Creates a reduction across the first activeWidth lanes of a subgroup, or
279 /// the entire subgroup if activeWidth is larger than the subgroup width.
280 /// The first lane returns the result, all others return values are undefined.
281 Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
282 AccumulatorFactory &accumFactory) {
283 Value subgroupSize = create<arith::ConstantIntOp>(args: int32Type, args: kSubgroupSize);
284 Value isPartialSubgroup = create<arith::CmpIOp>(args: arith::CmpIPredicate::slt,
285 args: activeWidth, args: subgroupSize);
286 std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
287
288 createIf(
289 condition: isPartialSubgroup,
290 // Generate reduction over a (potentially) partial subgroup.
291 thenOpsFactory: [&] {
292 Value value = operand;
293 // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
294 // lane is within the active range. The accumulated value is available
295 // in the first lane.
296 for (int i = 1; i < kSubgroupSize; i <<= 1) {
297 Value offset = create<arith::ConstantIntOp>(args: int32Type, args: i);
298 auto shuffleOp = create<gpu::ShuffleOp>(
299 args: shuffleType, args: value, args: offset, args: activeWidth, args: gpu::ShuffleMode::XOR);
300 // Skip the accumulation if the shuffle op read from a lane outside
301 // of the active range.
302 createIf(
303 condition: shuffleOp.getResult(i: 1),
304 thenOpsFactory: [&] {
305 return SmallVector<Value, 1>{
306 accumFactory(value, shuffleOp.getResult(i: 0))};
307 },
308 elseOpsFactory: [&] { return llvm::ArrayRef(value); });
309 value = rewriter.getInsertionBlock()->getArgument(i: 0);
310 }
311 return SmallVector<Value, 1>{value};
312 },
313 // Generate a reduction over the entire subgroup. This is a
314 // specialization of the above reduction with unconditional
315 // accumulation.
316 elseOpsFactory: [&] {
317 Value value = operand;
318 for (int i = 1; i < kSubgroupSize; i <<= 1) {
319 Value offset = create<arith::ConstantIntOp>(args: int32Type, args: i);
320 auto shuffleOp =
321 create<gpu::ShuffleOp>(args: shuffleType, args: value, args: offset, args: subgroupSize,
322 args: gpu::ShuffleMode::XOR);
323 value = accumFactory(value, shuffleOp.getResult(i: 0));
324 }
325 return SmallVector<Value, 1>{value};
326 });
327 return rewriter.getInsertionBlock()->getArgument(i: 0);
328 }
329
330 /// Returns value divided by the subgroup size (i.e. 32).
331 Value getDivideBySubgroupSize(Value value) {
332 Value subgroupSize = create<arith::ConstantIntOp>(args: int32Type, args: kSubgroupSize);
333 return create<arith::DivSIOp>(args: int32Type, args: value, args: subgroupSize);
334 }
335
336 gpu::GPUFuncOp funcOp;
337 gpu::AllReduceOp reduceOp;
338 PatternRewriter &rewriter;
339
340 Location loc;
341 Type valueType;
342 Type indexType;
343 IntegerType int32Type;
344
345 static constexpr int kSubgroupSize = 32;
346};
347
348struct GpuAllReduceRewrite : public RewritePattern {
349 explicit GpuAllReduceRewrite(MLIRContext *context)
350 : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
351
352 LogicalResult matchAndRewrite(Operation *op,
353 PatternRewriter &rewriter) const override {
354 auto funcOp = cast<gpu::GPUFuncOp>(Val: op);
355
356 SmallVector<gpu::AllReduceOp> reduceOps;
357 auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
358 if (!reduceOp.getUniform())
359 return WalkResult::interrupt();
360
361 reduceOps.emplace_back(Args&: reduceOp);
362 return WalkResult::advance();
363 };
364
365 if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
366 return rewriter.notifyMatchFailure(
367 arg&: op, msg: "Non uniform reductions are not supported yet.");
368
369 for (gpu::AllReduceOp reduceOp : reduceOps)
370 GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
371
372 return success();
373 }
374};
375} // namespace
376
377void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
378 patterns.add<GpuAllReduceRewrite>(arg: patterns.getContext());
379}
380

// Source: mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp