1//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements in-dialect rewriting of the gpu.subgroup_id op for
// architectures where the subgroup id is computed from the thread id, block
// dimensions, and subgroup size as:
//   subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
12//
13//===----------------------------------------------------------------------===//
14
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
21
22using namespace mlir;
23
24namespace {
25struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
26 using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
27
28 LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
29 PatternRewriter &rewriter) const override {
30 // Calculation of the thread's subgroup identifier.
31 //
32 // The process involves mapping the thread's 3D identifier within its
33 // block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
34 // This linearization assumes a layout where the x-dimension (w_dim.x)
35 // varies most rapidly (i.e., it is the innermost dimension).
36 //
37 // The formula for the linearized thread index is:
38 // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
39 //
40 // Subsequently, the range of linearized indices [0, N_threads-1] is
41 // divided into consecutive, non-overlapping segments, each representing
42 // a subgroup of size 'subgroup_size'.
43 //
44 // Example Partitioning (N = subgroup_size):
45 // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
46 // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
47 //
48 // The subgroup identifier is obtained via integer division of the
49 // linearized thread index by the predefined 'subgroup_size'.
50 //
51 // subgroup_id = floor( L / subgroup_size )
52 // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
53 // subgroup_size
54
55 Location loc = op->getLoc();
56 Type indexType = rewriter.getIndexType();
57
58 Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
59 Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
60 Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
61 Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
62 Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
63
64 Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, indexType, dimY, tidZ);
65 Value dimYxIdZPlusIdY =
66 rewriter.create<arith::AddIOp>(loc, indexType, dimYxIdZ, tidY);
67 Value dimYxIdZPlusIdYTimesDimX =
68 rewriter.create<arith::MulIOp>(loc, indexType, dimX, dimYxIdZPlusIdY);
69 Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create<arith::AddIOp>(
70 loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX);
71 Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
72 loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
73 Value subgroupIdOp = rewriter.create<arith::DivUIOp>(
74 loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
75 rewriter.replaceOp(op, {subgroupIdOp});
76 return success();
77 }
78};
79
80} // namespace
81
82void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
83 patterns.add<GpuSubgroupIdRewriter>(arg: patterns.getContext());
84}
85

source code of mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp