//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

using namespace mlir;
using namespace mlir::gpu;

// Maximum grid and block dimensions of all known GPUs are less than 2^32.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Maximum number of thread blocks along any one dimension of a cluster on
// known GPUs.
static constexpr uint64_t kMaxClusterDim = 8;
// Subgroups (warps/waves) on all known GPUs are no larger than 128 lanes.
static constexpr uint64_t kMaxSubgroupSize = 128;

static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
  unsigned width = IndexType::kInternalStorageBitWidth;
  return ConstantIntRanges::fromUnsigned(APInt(width, umin),
                                         APInt(width, umax));
}

namespace {
enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
} // end namespace

/// If the operation `op` is in a context that is annotated with maximum
/// launch dimensions (a `gpu.launch` with constant block or grid sizes, or an
/// enclosing function carrying known-block-size/known-grid-size attributes),
/// the `getKnownLaunchDim` helper below returns the bound on the size of the
/// dimension that `op` is querying. IDs will be one less than this bound.
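///
/// For illustration (an IR sketch, not a verbatim test case): a kernel such as
///   gpu.func @kernel() kernel
///       attributes {known_block_size = array<i32: 128, 1, 1>} { ... }
/// would let `gpu.block_dim x` infer the exact range [128, 128] and
/// `gpu.thread_id x` the range [0, 127].

/// Return the component of `dims` selected by `dim`.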
static Value valueByDim(KernelDim3 dims, Dimension dim) {
  switch (dim) {
  case Dimension::x:
    return dims.x;
  case Dimension::y:
    return dims.y;
  case Dimension::z:
    return dims.z;
  }
  llvm_unreachable("All dimension enum cases handled above");
}

static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }

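/// Return the known launch bound for `dim` recorded in the
/// `known_block_size`/`known_grid_size` inherent attributes of a gpu.func,
/// if the attribute is present and has an entry for that dimension.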
static std::optional<uint64_t>
getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
  DenseI32ArrayAttr bounds;
  switch (dims) {
  case LaunchDims::Block:
    bounds = func.getKnownBlockSizeAttr();
    break;
  case LaunchDims::Grid:
    bounds = func.getKnownGridSizeAttr();
    break;
  }
  if (!bounds)
    return std::nullopt;
  if (bounds.size() <= static_cast<uint32_t>(dim))
    return std::nullopt;
  return zext(bounds[static_cast<uint32_t>(dim)]);
}

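/// Return the known launch bound for `dim` recorded in a discardable
/// attribute (such as `gpu.known_block_size`) on an arbitrary function-like
/// op, if the attribute is present and has an entry for that dimension.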
static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
                                                  StringRef attrName,
                                                  Dimension dim) {
  auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
  if (!bounds)
    return std::nullopt;
  if (bounds.size() <= static_cast<uint32_t>(dim))
    return std::nullopt;
  return zext(bounds[static_cast<uint32_t>(dim)]);
}

template <typename Op>
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
  Dimension dim = op.getDimension();
  if (auto launch = op->template getParentOfType<LaunchOp>()) {
    KernelDim3 bounds;
    switch (type) {
    case LaunchDims::Block:
      bounds = launch.getBlockSizeOperandValues();
      break;
    case LaunchDims::Grid:
      bounds = launch.getGridSizeOperandValues();
      break;
    }
    Value maybeBound = valueByDim(bounds, dim);
    APInt value;
    if (matchPattern(maybeBound, m_ConstantInt(&value)))
      return value.getZExtValue();
  }

  if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
    auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
    if (inherentAttr)
      return inherentAttr;
  }
  if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
    StringRef attrName;
    switch (type) {
    case LaunchDims::Block:
      attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
      break;
    case LaunchDims::Grid:
      attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
      break;
    }
    auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
    if (discardableAttr)
      return discardableAttr;
  }
  return std::nullopt;
}

void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                     SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                           SetIntRangeFn setResultRange) {
  uint64_t max = kMaxClusterDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                         SetIntRangeFn setResultRange) {
  uint64_t max = kMaxClusterDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
  std::optional<uint64_t> knownVal =
      getKnownLaunchDim(*this, LaunchDims::Block);
  if (knownVal)
    return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                  SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
    max = fromContext.value();
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                  SetIntRangeFn setResultRange) {
  std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
  if (knownVal)
    return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
    max = fromContext.value();
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                 SetIntRangeFn setResultRange) {
  uint64_t max = kMaxSubgroupSize;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                     SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}

void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
  if (auto specified = getUpperBound())
    return setResultRange(getResult(),
                          getIndexRange(0, specified->getZExtValue() - 1ULL));

  uint64_t blockDimMax =
      getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
  uint64_t gridDimMax =
      getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
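  // Both factors are at most 2^32 - 1, so their product stays below 2^64 and
  // the multiplication below cannot overflow uint64_t.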
  setResultRange(getResult(),
                 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
}

void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                       SetIntRangeFn setResultRange) {
  uint64_t max = kMaxDim;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                       SetIntRangeFn setResultRange) {
  uint64_t max = kMaxSubgroupSize;
  if (auto specified = getUpperBound())
    max = specified->getZExtValue();
  setResultRange(getResult(), getIndexRange(1, max));
}

void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
                      Value idxResult) {
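    // The launch sizes are index-typed, so their ranges are expected at the
    // index storage bit width; skip anything else rather than mixing widths
    // in the intersection below.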
    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
      return;
    ConstantIntRanges dimRange =
        argRange.intersection(getIndexRange(1, kMaxDim));
    setResultRange(dimResult, dimRange);
    ConstantIntRanges idxRange =
        getIndexRange(0, dimRange.umax().getZExtValue() - 1);
    setResultRange(idxResult, idxRange);
  };

  argRanges = argRanges.drop_front(getAsyncDependencies().size());
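  // After the async dependencies, the first six operands are the grid sizes
  // (x, y, z) followed by the block sizes (x, y, z), so argRanges[0..5] line
  // up with the dimension/ID results paired below.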
  KernelDim3 gridDims = getGridSize();
  KernelDim3 blockIds = getBlockIds();
  setRange(argRanges[0], gridDims.x, blockIds.x);
  setRange(argRanges[1], gridDims.y, blockIds.y);
  setRange(argRanges[2], gridDims.z, blockIds.z);
  KernelDim3 blockDims = getBlockSize();
  KernelDim3 threadIds = getThreadIds();
  setRange(argRanges[3], blockDims.x, threadIds.x);
  setRange(argRanges[4], blockDims.y, threadIds.y);
  setRange(argRanges[5], blockDims.z, threadIds.z);
}