1 | //===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements in-dialect rewriting of the gpu.subgroup_id op for archs |
10 | // where: |
11 | // subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
16 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
17 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | |
22 | using namespace mlir; |
23 | |
24 | namespace { |
25 | struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> { |
26 | using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern; |
27 | |
28 | LogicalResult matchAndRewrite(gpu::SubgroupIdOp op, |
29 | PatternRewriter &rewriter) const override { |
30 | // Calculation of the thread's subgroup identifier. |
31 | // |
32 | // The process involves mapping the thread's 3D identifier within its |
33 | // block (b_id.x, b_id.y, b_id.z) to a 1D linear index. |
34 | // This linearization assumes a layout where the x-dimension (w_dim.x) |
35 | // varies most rapidly (i.e., it is the innermost dimension). |
36 | // |
37 | // The formula for the linearized thread index is: |
38 | // L = tid.x + dim.x * (tid.y + (dim.y * tid.z)) |
39 | // |
40 | // Subsequently, the range of linearized indices [0, N_threads-1] is |
41 | // divided into consecutive, non-overlapping segments, each representing |
42 | // a subgroup of size 'subgroup_size'. |
43 | // |
44 | // Example Partitioning (N = subgroup_size): |
45 | // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... | |
46 | // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... | |
47 | // |
48 | // The subgroup identifier is obtained via integer division of the |
49 | // linearized thread index by the predefined 'subgroup_size'. |
50 | // |
51 | // subgroup_id = floor( L / subgroup_size ) |
52 | // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / |
53 | // subgroup_size |
54 | |
55 | Location loc = op->getLoc(); |
56 | Type indexType = rewriter.getIndexType(); |
57 | |
58 | Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); |
59 | Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y); |
60 | Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
61 | Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y); |
62 | Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z); |
63 | |
64 | Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, indexType, dimY, tidZ); |
65 | Value dimYxIdZPlusIdY = |
66 | rewriter.create<arith::AddIOp>(loc, indexType, dimYxIdZ, tidY); |
67 | Value dimYxIdZPlusIdYTimesDimX = |
68 | rewriter.create<arith::MulIOp>(loc, indexType, dimX, dimYxIdZPlusIdY); |
69 | Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create<arith::AddIOp>( |
70 | loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX); |
71 | Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>( |
72 | loc, rewriter.getIndexType(), /*upper_bound = */ nullptr); |
73 | Value subgroupIdOp = rewriter.create<arith::DivUIOp>( |
74 | loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize); |
75 | rewriter.replaceOp(op, {subgroupIdOp}); |
76 | return success(); |
77 | } |
78 | }; |
79 | |
80 | } // namespace |
81 | |
82 | void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) { |
83 | patterns.add<GpuSubgroupIdRewriter>(arg: patterns.getContext()); |
84 | } |
85 | |