| 1 | //===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements in-dialect rewriting of the gpu.subgroup_id op for archs |
| 10 | // where: |
| 11 | // subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
| 21 | |
| 22 | using namespace mlir; |
| 23 | |
| 24 | namespace { |
| 25 | struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> { |
| 26 | using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern; |
| 27 | |
| 28 | LogicalResult matchAndRewrite(gpu::SubgroupIdOp op, |
| 29 | PatternRewriter &rewriter) const override { |
| 30 | // Calculation of the thread's subgroup identifier. |
| 31 | // |
| 32 | // The process involves mapping the thread's 3D identifier within its |
| 33 | // block (b_id.x, b_id.y, b_id.z) to a 1D linear index. |
| 34 | // This linearization assumes a layout where the x-dimension (w_dim.x) |
| 35 | // varies most rapidly (i.e., it is the innermost dimension). |
| 36 | // |
| 37 | // The formula for the linearized thread index is: |
| 38 | // L = tid.x + dim.x * (tid.y + (dim.y * tid.z)) |
| 39 | // |
| 40 | // Subsequently, the range of linearized indices [0, N_threads-1] is |
| 41 | // divided into consecutive, non-overlapping segments, each representing |
| 42 | // a subgroup of size 'subgroup_size'. |
| 43 | // |
| 44 | // Example Partitioning (N = subgroup_size): |
| 45 | // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... | |
| 46 | // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... | |
| 47 | // |
| 48 | // The subgroup identifier is obtained via integer division of the |
| 49 | // linearized thread index by the predefined 'subgroup_size'. |
| 50 | // |
| 51 | // subgroup_id = floor( L / subgroup_size ) |
| 52 | // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / |
| 53 | // subgroup_size |
| 54 | |
| 55 | Location loc = op->getLoc(); |
| 56 | Type indexType = rewriter.getIndexType(); |
| 57 | |
| 58 | Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); |
| 59 | Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y); |
| 60 | Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
| 61 | Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y); |
| 62 | Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z); |
| 63 | |
| 64 | Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, indexType, dimY, tidZ); |
| 65 | Value dimYxIdZPlusIdY = |
| 66 | rewriter.create<arith::AddIOp>(loc, indexType, dimYxIdZ, tidY); |
| 67 | Value dimYxIdZPlusIdYTimesDimX = |
| 68 | rewriter.create<arith::MulIOp>(loc, indexType, dimX, dimYxIdZPlusIdY); |
| 69 | Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create<arith::AddIOp>( |
| 70 | loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX); |
| 71 | Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>( |
| 72 | loc, rewriter.getIndexType(), /*upper_bound = */ nullptr); |
| 73 | Value subgroupIdOp = rewriter.create<arith::DivUIOp>( |
| 74 | loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize); |
| 75 | rewriter.replaceOp(op, {subgroupIdOp}); |
| 76 | return success(); |
| 77 | } |
| 78 | }; |
| 79 | |
| 80 | } // namespace |
| 81 | |
| 82 | void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) { |
| 83 | patterns.add<GpuSubgroupIdRewriter>(arg: patterns.getContext()); |
| 84 | } |
| 85 | |