1 | //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
10 | #include "mlir/IR/Matchers.h" |
11 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
12 | #include "llvm/ADT/STLForwardCompat.h" |
13 | #include "llvm/Support/ErrorHandling.h" |
14 | #include "llvm/Support/MathExtras.h" |
15 | #include <optional> |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::gpu; |
19 | |
// Maximum grid and block dimensions of all known GPUs are less than 2^32.
// Used as the fallback upper bound when no launch bounds are known.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Maximum number of blocks in a cluster (upper bound used when no tighter
// bound is available).
static constexpr uint64_t kMaxClusterDim = 8;
// Maximum subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
26 | |
27 | static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) { |
28 | unsigned width = IndexType::kInternalStorageBitWidth; |
29 | return ConstantIntRanges::fromUnsigned(umin: APInt(width, umin), |
30 | umax: APInt(width, umax)); |
31 | } |
32 | |
namespace {
// Selects which set of launch dimensions a query refers to: the block
// (thread-block) sizes or the grid sizes.
enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
} // end namespace
36 | |
/// If the operation `op` is in a context that is annotated with maximum
/// launch dimensions (a launch op with constant block or grid
/// sizes or a launch_func op with the appropriate dimensions), return
/// the bound on the maximum size of the dimension that the op is querying.
/// IDs will be one less than this bound. (This documents getKnownLaunchDim()
/// below; valueByDim() is only a small helper it uses.)

43 | static Value valueByDim(KernelDim3 dims, Dimension dim) { |
44 | switch (dim) { |
45 | case Dimension::x: |
46 | return dims.x; |
47 | case Dimension::y: |
48 | return dims.y; |
49 | case Dimension::z: |
50 | return dims.z; |
51 | } |
52 | llvm_unreachable("All dimension enum cases handled above" ); |
53 | } |
54 | |
// Widen a 32-bit value to 64 bits; used as a transformOptional callback.
static uint64_t zext(uint32_t arg) {
  uint64_t wide = arg;
  return wide;
}
56 | |
/// Return the constant launch bound (block or grid size, per `type`) for the
/// dimension that `op` queries, if it can be determined statically.
///
/// Two sources are consulted, in order:
///  1. An enclosing gpu.launch whose corresponding size operand for this
///     dimension is a constant.
///  2. An enclosing gpu.func carrying known block/grid size annotations.
/// Returns std::nullopt when neither yields a constant.
template <typename Op>
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
  Dimension dim = op.getDimension();
  if (auto launch = op->template getParentOfType<LaunchOp>()) {
    // Pick the block-size or grid-size operand triple as requested.
    KernelDim3 bounds;
    switch (type) {
    case LaunchDims::Block:
      bounds = launch.getBlockSizeOperandValues();
      break;
    case LaunchDims::Grid:
      bounds = launch.getGridSizeOperandValues();
      break;
    }
    Value maybeBound = valueByDim(bounds, dim);
    APInt value;
    // Only a constant operand gives a usable bound.
    if (matchPattern(maybeBound, m_ConstantInt(&value)))
      return value.getZExtValue();
  }

  // Fall back to size annotations on the enclosing gpu.func, widening the
  // 32-bit annotation value to 64 bits when present.
  if (auto func = op->template getParentOfType<GPUFuncOp>()) {
    switch (type) {
    case LaunchDims::Block:
      return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
    case LaunchDims::Grid:
      return llvm::transformOptional(func.getKnownGridSize(dim), zext);
    }
  }
  return std::nullopt;
}
86 | |
87 | void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
88 | SetIntRangeFn setResultRange) { |
89 | setResultRange(getResult(), getIndexRange(1, kMaxClusterDim)); |
90 | } |
91 | |
92 | void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
93 | SetIntRangeFn setResultRange) { |
94 | uint64_t max = kMaxClusterDim; |
95 | setResultRange(getResult(), getIndexRange(0, max - 1ULL)); |
96 | } |
97 | |
98 | void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
99 | SetIntRangeFn setResultRange) { |
100 | std::optional<uint64_t> knownVal = |
101 | getKnownLaunchDim(*this, LaunchDims::Block); |
102 | if (knownVal) |
103 | setResultRange(getResult(), getIndexRange(*knownVal, *knownVal)); |
104 | else |
105 | setResultRange(getResult(), getIndexRange(1, kMaxDim)); |
106 | } |
107 | |
108 | void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
109 | SetIntRangeFn setResultRange) { |
110 | uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim); |
111 | setResultRange(getResult(), getIndexRange(0, max - 1ULL)); |
112 | } |
113 | |
114 | void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
115 | SetIntRangeFn setResultRange) { |
116 | std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid); |
117 | if (knownVal) |
118 | setResultRange(getResult(), getIndexRange(*knownVal, *knownVal)); |
119 | else |
120 | setResultRange(getResult(), getIndexRange(1, kMaxDim)); |
121 | } |
122 | |
123 | void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
124 | SetIntRangeFn setResultRange) { |
125 | uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim); |
126 | setResultRange(getResult(), getIndexRange(0, max - 1ULL)); |
127 | } |
128 | |
129 | void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
130 | SetIntRangeFn setResultRange) { |
131 | setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL)); |
132 | } |
133 | |
134 | void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
135 | SetIntRangeFn setResultRange) { |
136 | setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL)); |
137 | } |
138 | |
139 | void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
140 | SetIntRangeFn setResultRange) { |
141 | uint64_t blockDimMax = |
142 | getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim); |
143 | uint64_t gridDimMax = |
144 | getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim); |
145 | setResultRange(getResult(), |
146 | getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL)); |
147 | } |
148 | |
149 | void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
150 | SetIntRangeFn setResultRange) { |
151 | setResultRange(getResult(), getIndexRange(1, kMaxDim)); |
152 | } |
153 | |
154 | void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>, |
155 | SetIntRangeFn setResultRange) { |
156 | setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize)); |
157 | } |
158 | |
/// Infer ranges for the region arguments of a gpu.launch from the inferred
/// ranges of its grid/block size operands: each size value lies in
/// [1, kMaxDim] and the corresponding id value in [0, sizeMax - 1].
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  // Propagate one (size operand range) -> (dim result, id result) pair.
  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
                      Value idxResult) {
    // Skip operands whose range was computed at a width other than the
    // index storage width; intersecting mismatched widths is not valid.
    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
      return;
    // Sizes are clamped to the legal [1, kMaxDim] window.
    ConstantIntRanges dimRange =
        argRange.intersection(getIndexRange(1, kMaxDim));
    setResultRange(dimResult, dimRange);
    // IDs are zero-based: [0, sizeMax - 1].
    ConstantIntRanges idxRange =
        getIndexRange(0, dimRange.umax().getZExtValue() - 1);
    setResultRange(idxResult, idxRange);
  };

  // argRanges lists async dependencies first; the six launch-size operands
  // follow in grid-x/y/z, block-x/y/z order — the indexing below relies on
  // that layout.
  argRanges = argRanges.drop_front(getAsyncDependencies().size());
  KernelDim3 gridDims = getGridSize();
  KernelDim3 blockIds = getBlockIds();
  setRange(argRanges[0], gridDims.x, blockIds.x);
  setRange(argRanges[1], gridDims.y, blockIds.y);
  setRange(argRanges[2], gridDims.z, blockIds.z);
  KernelDim3 blockDims = getBlockSize();
  KernelDim3 threadIds = getThreadIds();
  setRange(argRanges[3], blockDims.x, threadIds.x);
  setRange(argRanges[4], blockDims.y, threadIds.y);
  setRange(argRanges[5], blockDims.z, threadIds.z);
}
185 | |