| 1 | //===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements in-dialect lowering of the all-reduce op to a block of |
| 10 | // simpler instructions. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 17 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 19 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 20 | #include "mlir/IR/Builders.h" |
| 21 | #include "mlir/IR/IRMapping.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | |
| 24 | using namespace mlir; |
| 25 | |
| 26 | namespace { |
| 27 | |
| 28 | struct GpuAllReduceRewriter { |
| 29 | using AccumulatorFactory = std::function<Value(Value, Value)>; |
| 30 | |
| 31 | GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp, |
| 32 | PatternRewriter &rewriter) |
| 33 | : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter), |
| 34 | loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()), |
| 35 | indexType(IndexType::get(context: reduceOp.getContext())), |
| 36 | int32Type(IntegerType::get(context: reduceOp.getContext(), /*width=*/32)) {} |
| 37 | |
| 38 | /// Creates an all_reduce across the workgroup. |
| 39 | /// |
| 40 | /// First reduce the elements within a subgroup. The first invocation of each |
| 41 | /// subgroup writes the intermediate result to workgroup memory. After |
| 42 | /// synchronizing the workgroup, the first subgroup reduces the values from |
| 43 | /// workgroup memory. The result is broadcasted to all invocations through |
| 44 | /// workgroup memory. |
| 45 | /// |
| 46 | /// %subgroup_reduce = `createSubgroupReduce(%operand)` |
| 47 | /// cf.cond_br %is_first_lane, ^then1, ^continue1 |
| 48 | /// ^then1: |
| 49 | /// store %subgroup_reduce, %workgroup_buffer[%subgroup_id] |
| 50 | /// cf.br ^continue1 |
| 51 | /// ^continue1: |
| 52 | /// gpu.barrier |
| 53 | /// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups |
| 54 | /// cf.cond_br %is_valid_subgroup, ^then2, ^continue2 |
| 55 | /// ^then2: |
| 56 | /// %partial_reduce = load %workgroup_buffer[%invocation_idx] |
| 57 | /// %all_reduce = `createSubgroupReduce(%partial_reduce)` |
| 58 | /// store %all_reduce, %workgroup_buffer[%zero] |
| 59 | /// llvm.br ^continue2 |
| 60 | /// ^continue2: |
| 61 | /// gpu.barrier |
| 62 | /// %result = load %workgroup_buffer[%zero] |
| 63 | /// return %result |
| 64 | /// |
| 65 | void rewrite() { |
| 66 | rewriter.setInsertionPoint(reduceOp); |
| 67 | |
| 68 | // Compute linear invocation index and workgroup size. |
| 69 | Value dimX = getDimOp<gpu::BlockDimOp>(dimension: gpu::Dimension::x); |
| 70 | Value dimY = getDimOp<gpu::BlockDimOp>(dimension: gpu::Dimension::y); |
| 71 | Value dimZ = getDimOp<gpu::BlockDimOp>(dimension: gpu::Dimension::z); |
| 72 | Value tidX = getDimOp<gpu::ThreadIdOp>(dimension: gpu::Dimension::x); |
| 73 | Value tidY = getDimOp<gpu::ThreadIdOp>(dimension: gpu::Dimension::y); |
| 74 | Value tidZ = getDimOp<gpu::ThreadIdOp>(dimension: gpu::Dimension::z); |
| 75 | Value tmp1 = create<arith::MulIOp>(args: int32Type, args: tidZ, args: dimY); |
| 76 | Value tmp2 = create<arith::AddIOp>(args: int32Type, args: tmp1, args: tidY); |
| 77 | Value tmp3 = create<arith::MulIOp>(args: int32Type, args: tmp2, args: dimX); |
| 78 | Value tmp4 = create<arith::MulIOp>(args: int32Type, args: dimX, args: dimY); |
| 79 | Value invocationIdx = create<arith::AddIOp>(args: int32Type, args: tmp3, args: tidX); |
| 80 | Value workgroupSize = create<arith::MulIOp>(args: int32Type, args: tmp4, args: dimZ); |
| 81 | |
| 82 | // Compute lane id (invocation id withing the subgroup). |
| 83 | Value subgroupMask = |
| 84 | create<arith::ConstantIntOp>(args: int32Type, args: kSubgroupSize - 1); |
| 85 | Value laneId = create<arith::AndIOp>(args: invocationIdx, args: subgroupMask); |
| 86 | Value isFirstLane = |
| 87 | create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args: laneId, |
| 88 | args: create<arith::ConstantIntOp>(args: int32Type, args: 0)); |
| 89 | |
| 90 | Value numThreadsWithSmallerSubgroupId = |
| 91 | create<arith::SubIOp>(args: invocationIdx, args: laneId); |
| 92 | // The number of active invocations starting from the current subgroup. |
| 93 | // The consumers do not require the value to be clamped to the size of the |
| 94 | // subgroup. |
| 95 | Value activeWidth = |
| 96 | create<arith::SubIOp>(args: workgroupSize, args: numThreadsWithSmallerSubgroupId); |
| 97 | |
| 98 | // Create factory for op which accumulates to values. |
| 99 | AccumulatorFactory accumFactory = getFactory(); |
| 100 | assert(accumFactory && "failed to create accumulator factory" ); |
| 101 | |
| 102 | // Reduce elements within each subgroup to produce the intermediate results. |
| 103 | Value subgroupReduce = createSubgroupReduce( |
| 104 | activeWidth, laneId, operand: reduceOp.getValue(), accumFactory); |
| 105 | |
| 106 | // Add workgroup buffer to parent function for intermediate result. |
| 107 | Value buffer = createWorkgroupBuffer(); |
| 108 | |
| 109 | // Write the intermediate results to workgroup memory, using the first lane |
| 110 | // of each subgroup. |
| 111 | createPredicatedBlock(condition: isFirstLane, predicatedOpsFactory: [&] { |
| 112 | Value subgroupId = getDivideBySubgroupSize(value: invocationIdx); |
| 113 | Value index = create<arith::IndexCastOp>(args: indexType, args: subgroupId); |
| 114 | create<memref::StoreOp>(args: subgroupReduce, args: buffer, args: index); |
| 115 | }); |
| 116 | create<gpu::BarrierOp>(); |
| 117 | |
| 118 | // Compute number of active subgroups. |
| 119 | Value biasedBlockSize = |
| 120 | create<arith::AddIOp>(args: int32Type, args: workgroupSize, args: subgroupMask); |
| 121 | Value numSubgroups = getDivideBySubgroupSize(value: biasedBlockSize); |
| 122 | Value isValidSubgroup = create<arith::CmpIOp>(args: arith::CmpIPredicate::slt, |
| 123 | args: invocationIdx, args: numSubgroups); |
| 124 | |
| 125 | // Use the first numSubgroups invocations to reduce the intermediate results |
| 126 | // from workgroup memory. The final result is written to workgroup memory |
| 127 | // again. |
| 128 | Value zero = create<arith::ConstantIndexOp>(args: 0); |
| 129 | createPredicatedBlock(condition: isValidSubgroup, predicatedOpsFactory: [&] { |
| 130 | Value index = create<arith::IndexCastOp>(args: indexType, args: invocationIdx); |
| 131 | Value value = create<memref::LoadOp>(args: valueType, args: buffer, args: index); |
| 132 | Value result = |
| 133 | createSubgroupReduce(activeWidth: numSubgroups, laneId, operand: value, accumFactory); |
| 134 | create<memref::StoreOp>(args: result, args: buffer, args: zero); |
| 135 | }); |
| 136 | |
| 137 | // Synchronize workgroup and load result from workgroup memory. |
| 138 | create<gpu::BarrierOp>(); |
| 139 | Value result = create<memref::LoadOp>(args: valueType, args: buffer, args: zero); |
| 140 | |
| 141 | rewriter.replaceOp(op: reduceOp, newValues: result); |
| 142 | } |
| 143 | |
| 144 | private: |
| 145 | // Shortcut to create an op from rewriter using loc as the first argument. |
| 146 | template <typename T, typename... Args> |
| 147 | T create(Args... args) { |
| 148 | return rewriter.create<T>(loc, std::forward<Args>(args)...); |
| 149 | } |
| 150 | |
| 151 | // Creates dimension op of type T, with the result casted to int32. |
| 152 | template <typename T> |
| 153 | Value getDimOp(gpu::Dimension dimension) { |
| 154 | Value dim = create<T>(indexType, dimension); |
| 155 | return create<arith::IndexCastOp>(args: int32Type, args: dim); |
| 156 | } |
| 157 | |
| 158 | /// Adds type to funcOp's workgroup attributions. |
| 159 | Value createWorkgroupBuffer() { |
| 160 | // TODO: Pick a proper location for the attribution. |
| 161 | auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get( |
| 162 | context: funcOp->getContext(), value: gpu::GPUDialect::getWorkgroupAddressSpace()); |
| 163 | auto bufferType = MemRefType::get(shape: {kSubgroupSize}, elementType: valueType, map: AffineMap{}, |
| 164 | memorySpace: workgroupMemoryAddressSpace); |
| 165 | return funcOp.addWorkgroupAttribution(type: bufferType, loc: rewriter.getUnknownLoc()); |
| 166 | } |
| 167 | |
| 168 | /// Returns an accumulator factory using either the op attribute or the body |
| 169 | /// region. |
| 170 | AccumulatorFactory getFactory() { |
| 171 | auto &body = reduceOp.getBody(); |
| 172 | if (!body.empty()) |
| 173 | return getFactory(body); |
| 174 | auto opAttr = reduceOp.getOp(); |
| 175 | if (opAttr) |
| 176 | return getFactory(opName: *opAttr); |
| 177 | return AccumulatorFactory(); |
| 178 | } |
| 179 | |
| 180 | /// Returns an accumulator factory that clones the body. The body's entry |
| 181 | /// block is expected to have 2 arguments. The gpu.yield return the |
| 182 | /// accumulated value of the same type. |
| 183 | AccumulatorFactory getFactory(Region &body) { |
| 184 | return [&body, this](Value lhs, Value rhs) -> Value { |
| 185 | Block *block = rewriter.getInsertionBlock(); |
| 186 | Block *split = rewriter.splitBlock(block, before: rewriter.getInsertionPoint()); |
| 187 | |
| 188 | // Insert accumulator body between split block. |
| 189 | IRMapping mapping; |
| 190 | mapping.map(from: body.getArgument(i: 0), to: lhs); |
| 191 | mapping.map(from: body.getArgument(i: 1), to: rhs); |
| 192 | rewriter.cloneRegionBefore(region&: body, parent&: *split->getParent(), |
| 193 | before: split->getIterator(), mapping); |
| 194 | |
| 195 | // Add branch before inserted body, into body. |
| 196 | block = block->getNextNode(); |
| 197 | create<cf::BranchOp>(args: block, args: ValueRange()); |
| 198 | |
| 199 | // Replace all gpu.yield ops with branch out of body. |
| 200 | for (; block != split; block = block->getNextNode()) { |
| 201 | Operation *terminator = block->getTerminator(); |
| 202 | if (!isa<gpu::YieldOp>(Val: terminator)) |
| 203 | continue; |
| 204 | rewriter.setInsertionPointToEnd(block); |
| 205 | rewriter.replaceOpWithNewOp<cf::BranchOp>( |
| 206 | op: terminator, args&: split, args: ValueRange(terminator->getOperand(idx: 0))); |
| 207 | } |
| 208 | |
| 209 | // Return accumulator result. |
| 210 | rewriter.setInsertionPointToStart(split); |
| 211 | return split->addArgument(type: lhs.getType(), loc: lhs.getLoc()); |
| 212 | }; |
| 213 | } |
| 214 | |
| 215 | /// Returns an accumulator factory that creates an op specified by opName. |
| 216 | AccumulatorFactory getFactory(gpu::AllReduceOperation opName) { |
| 217 | return [opName, this](Value lhs, Value rhs) { |
| 218 | return vector::makeArithReduction(b&: rewriter, loc, |
| 219 | kind: convertReductionKind(mode: opName), v1: lhs, acc: rhs); |
| 220 | }; |
| 221 | } |
| 222 | |
| 223 | /// Creates an if-block skeleton and calls the two factories to generate the |
| 224 | /// ops in the `then` and `else` block.. |
| 225 | /// |
| 226 | /// llvm.cond_br %condition, ^then, ^continue |
| 227 | /// ^then: |
| 228 | /// %then_operands = `thenOpsFactory()` |
| 229 | /// llvm.br ^continue(%then_operands) |
| 230 | /// ^else: |
| 231 | /// %else_operands = `elseOpsFactory()` |
| 232 | /// llvm.br ^continue(%else_operands) |
| 233 | /// ^continue(%block_operands): |
| 234 | /// |
| 235 | template <typename ThenOpsFactory, typename ElseOpsFactory> |
| 236 | void createIf(Value condition, ThenOpsFactory &&thenOpsFactory, |
| 237 | ElseOpsFactory &&elseOpsFactory) { |
| 238 | Block *currentBlock = rewriter.getInsertionBlock(); |
| 239 | auto currentPoint = rewriter.getInsertionPoint(); |
| 240 | |
| 241 | Block *thenBlock = rewriter.splitBlock(block: currentBlock, before: currentPoint); |
| 242 | Block *elseBlock = rewriter.splitBlock(block: thenBlock, before: thenBlock->begin()); |
| 243 | Block *continueBlock = rewriter.splitBlock(block: elseBlock, before: elseBlock->begin()); |
| 244 | |
| 245 | rewriter.setInsertionPointToEnd(currentBlock); |
| 246 | create<cf::CondBranchOp>(args: condition, args: thenBlock, |
| 247 | /*trueOperands=*/args: ArrayRef<Value>(), args: elseBlock, |
| 248 | /*falseOperands=*/args: ArrayRef<Value>()); |
| 249 | |
| 250 | rewriter.setInsertionPointToStart(thenBlock); |
| 251 | auto thenOperands = thenOpsFactory(); |
| 252 | create<cf::BranchOp>(continueBlock, thenOperands); |
| 253 | |
| 254 | rewriter.setInsertionPointToStart(elseBlock); |
| 255 | auto elseOperands = elseOpsFactory(); |
| 256 | create<cf::BranchOp>(continueBlock, elseOperands); |
| 257 | |
| 258 | assert(thenOperands.size() == elseOperands.size()); |
| 259 | rewriter.setInsertionPointToStart(continueBlock); |
| 260 | for (auto operand : thenOperands) |
| 261 | continueBlock->addArgument(type: operand.getType(), loc: operand.getLoc()); |
| 262 | } |
| 263 | |
| 264 | /// Shortcut for createIf with empty else block and no block operands. |
| 265 | template <typename Factory> |
| 266 | void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) { |
| 267 | static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value, |
| 268 | "predicatedOpsFactory should not return any value" ); |
| 269 | createIf( |
| 270 | condition, |
| 271 | [&] { |
| 272 | predicatedOpsFactory(); |
| 273 | return ArrayRef<Value>(); |
| 274 | }, |
| 275 | [&] { return ArrayRef<Value>(); }); |
| 276 | } |
| 277 | |
| 278 | /// Creates a reduction across the first activeWidth lanes of a subgroup, or |
| 279 | /// the entire subgroup if activeWidth is larger than the subgroup width. |
| 280 | /// The first lane returns the result, all others return values are undefined. |
| 281 | Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand, |
| 282 | AccumulatorFactory &accumFactory) { |
| 283 | Value subgroupSize = create<arith::ConstantIntOp>(args: int32Type, args: kSubgroupSize); |
| 284 | Value isPartialSubgroup = create<arith::CmpIOp>(args: arith::CmpIPredicate::slt, |
| 285 | args: activeWidth, args: subgroupSize); |
| 286 | std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()}; |
| 287 | |
| 288 | createIf( |
| 289 | condition: isPartialSubgroup, |
| 290 | // Generate reduction over a (potentially) partial subgroup. |
| 291 | thenOpsFactory: [&] { |
| 292 | Value value = operand; |
| 293 | // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source |
| 294 | // lane is within the active range. The accumulated value is available |
| 295 | // in the first lane. |
| 296 | for (int i = 1; i < kSubgroupSize; i <<= 1) { |
| 297 | Value offset = create<arith::ConstantIntOp>(args: int32Type, args: i); |
| 298 | auto shuffleOp = create<gpu::ShuffleOp>( |
| 299 | args: shuffleType, args: value, args: offset, args: activeWidth, args: gpu::ShuffleMode::XOR); |
| 300 | // Skip the accumulation if the shuffle op read from a lane outside |
| 301 | // of the active range. |
| 302 | createIf( |
| 303 | condition: shuffleOp.getResult(i: 1), |
| 304 | thenOpsFactory: [&] { |
| 305 | return SmallVector<Value, 1>{ |
| 306 | accumFactory(value, shuffleOp.getResult(i: 0))}; |
| 307 | }, |
| 308 | elseOpsFactory: [&] { return llvm::ArrayRef(value); }); |
| 309 | value = rewriter.getInsertionBlock()->getArgument(i: 0); |
| 310 | } |
| 311 | return SmallVector<Value, 1>{value}; |
| 312 | }, |
| 313 | // Generate a reduction over the entire subgroup. This is a |
| 314 | // specialization of the above reduction with unconditional |
| 315 | // accumulation. |
| 316 | elseOpsFactory: [&] { |
| 317 | Value value = operand; |
| 318 | for (int i = 1; i < kSubgroupSize; i <<= 1) { |
| 319 | Value offset = create<arith::ConstantIntOp>(args: int32Type, args: i); |
| 320 | auto shuffleOp = |
| 321 | create<gpu::ShuffleOp>(args: shuffleType, args: value, args: offset, args: subgroupSize, |
| 322 | args: gpu::ShuffleMode::XOR); |
| 323 | value = accumFactory(value, shuffleOp.getResult(i: 0)); |
| 324 | } |
| 325 | return SmallVector<Value, 1>{value}; |
| 326 | }); |
| 327 | return rewriter.getInsertionBlock()->getArgument(i: 0); |
| 328 | } |
| 329 | |
| 330 | /// Returns value divided by the subgroup size (i.e. 32). |
| 331 | Value getDivideBySubgroupSize(Value value) { |
| 332 | Value subgroupSize = create<arith::ConstantIntOp>(args: int32Type, args: kSubgroupSize); |
| 333 | return create<arith::DivSIOp>(args: int32Type, args: value, args: subgroupSize); |
| 334 | } |
| 335 | |
| 336 | gpu::GPUFuncOp funcOp; |
| 337 | gpu::AllReduceOp reduceOp; |
| 338 | PatternRewriter &rewriter; |
| 339 | |
| 340 | Location loc; |
| 341 | Type valueType; |
| 342 | Type indexType; |
| 343 | IntegerType int32Type; |
| 344 | |
| 345 | static constexpr int kSubgroupSize = 32; |
| 346 | }; |
| 347 | |
| 348 | struct GpuAllReduceRewrite : public RewritePattern { |
| 349 | explicit GpuAllReduceRewrite(MLIRContext *context) |
| 350 | : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {} |
| 351 | |
| 352 | LogicalResult matchAndRewrite(Operation *op, |
| 353 | PatternRewriter &rewriter) const override { |
| 354 | auto funcOp = cast<gpu::GPUFuncOp>(Val: op); |
| 355 | |
| 356 | SmallVector<gpu::AllReduceOp> reduceOps; |
| 357 | auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult { |
| 358 | if (!reduceOp.getUniform()) |
| 359 | return WalkResult::interrupt(); |
| 360 | |
| 361 | reduceOps.emplace_back(Args&: reduceOp); |
| 362 | return WalkResult::advance(); |
| 363 | }; |
| 364 | |
| 365 | if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty()) |
| 366 | return rewriter.notifyMatchFailure( |
| 367 | arg&: op, msg: "Non uniform reductions are not supported yet." ); |
| 368 | |
| 369 | for (gpu::AllReduceOp reduceOp : reduceOps) |
| 370 | GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite(); |
| 371 | |
| 372 | return success(); |
| 373 | } |
| 374 | }; |
| 375 | } // namespace |
| 376 | |
| 377 | void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) { |
| 378 | patterns.add<GpuAllReduceRewrite>(arg: patterns.getContext()); |
| 379 | } |
| 380 | |