1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/GPU/IR/GPUDialect.h"
10#include "mlir/IR/Matchers.h"
11#include "mlir/Interfaces/FunctionInterfaces.h"
12#include "mlir/Interfaces/InferIntRangeInterface.h"
13#include "llvm/Support/ErrorHandling.h"
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::gpu;
18
// Maximum grid and block dimensions of all known GPUs are less than 2^32.
// Used as the default upper bound when no tighter launch bound is known.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Maximum cluster size (blocks per cluster) — 8 on current hardware.
// NOTE(review): revisit if future architectures allow larger clusters.
static constexpr uint64_t kMaxClusterDim = 8;
// Maximum subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
25
26static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
27 unsigned width = IndexType::kInternalStorageBitWidth;
28 return ConstantIntRanges::fromUnsigned(umin: APInt(width, umin),
29 umax: APInt(width, umax));
30}
31
namespace {
// Selects which set of launch bounds (block sizes vs. grid sizes) the
// known-launch-dimension helpers below should consult.
enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
} // end namespace
35
/// Returns the bound on the maximum size of the dimension queried by `op`
/// when `op` sits in a context annotated with maximum launch dimensions: an
/// enclosing gpu.launch with constant block or grid sizes, or a function
/// carrying known block/grid size attributes. IDs produced by `op` are
/// strictly less than this bound. (This documents getKnownLaunchDim below;
/// the immediately following valueByDim is an unrelated small helper.)
41
42static Value valueByDim(KernelDim3 dims, Dimension dim) {
43 switch (dim) {
44 case Dimension::x:
45 return dims.x;
46 case Dimension::y:
47 return dims.y;
48 case Dimension::z:
49 return dims.z;
50 }
51 llvm_unreachable("All dimension enum cases handled above");
52}
53
/// Widens a 32-bit launch-bound value to the 64-bit domain used by the range
/// computations in this file.
static uint64_t zext(uint32_t arg) { return uint64_t{arg}; }
55
56static std::optional<uint64_t>
57getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
58 DenseI32ArrayAttr bounds;
59 switch (dims) {
60 case LaunchDims::Block:
61 bounds = func.getKnownBlockSizeAttr();
62 break;
63 case LaunchDims::Grid:
64 bounds = func.getKnownGridSizeAttr();
65 break;
66 }
67 if (!bounds)
68 return std::nullopt;
69 if (bounds.size() < static_cast<uint32_t>(dim))
70 return std::nullopt;
71 return zext(arg: bounds[static_cast<uint32_t>(dim)]);
72}
73
74static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
75 StringRef attrName,
76 Dimension dim) {
77 auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(name: attrName);
78 if (!bounds)
79 return std::nullopt;
80 if (bounds.size() < static_cast<uint32_t>(dim))
81 return std::nullopt;
82 return zext(arg: bounds[static_cast<uint32_t>(dim)]);
83}
84
85template <typename Op>
86static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
87 Dimension dim = op.getDimension();
88 if (auto launch = op->template getParentOfType<LaunchOp>()) {
89 KernelDim3 bounds;
90 switch (type) {
91 case LaunchDims::Block:
92 bounds = launch.getBlockSizeOperandValues();
93 break;
94 case LaunchDims::Grid:
95 bounds = launch.getGridSizeOperandValues();
96 break;
97 }
98 Value maybeBound = valueByDim(dims: bounds, dim);
99 APInt value;
100 if (matchPattern(value: maybeBound, pattern: m_ConstantInt(bind_value: &value)))
101 return value.getZExtValue();
102 }
103
104 if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
105 auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
106 if (inherentAttr)
107 return inherentAttr;
108 }
109 if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
110 StringRef attrName;
111 switch (type) {
112 case LaunchDims::Block:
113 attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
114 break;
115 case LaunchDims::Grid:
116 attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
117 break;
118 }
119 auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
120 if (discardableAttr)
121 return discardableAttr;
122 }
123 return std::nullopt;
124}
125
126void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
127 SetIntRangeFn setResultRange) {
128 uint64_t max = kMaxDim;
129 if (auto specified = getUpperBound())
130 max = specified->getZExtValue();
131 setResultRange(getResult(), getIndexRange(umin: 1, umax: max));
132}
133
134void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
135 SetIntRangeFn setResultRange) {
136 uint64_t max = kMaxClusterDim;
137 if (auto specified = getUpperBound())
138 max = specified->getZExtValue();
139 setResultRange(getResult(), getIndexRange(umin: 1, umax: max));
140}
141
142void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
143 SetIntRangeFn setResultRange) {
144 uint64_t max = kMaxDim;
145 if (auto specified = getUpperBound())
146 max = specified->getZExtValue();
147 setResultRange(getResult(), getIndexRange(umin: 0, umax: max - 1ULL));
148}
149
150void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
151 SetIntRangeFn setResultRange) {
152 uint64_t max = kMaxClusterDim;
153 if (auto specified = getUpperBound())
154 max = specified->getZExtValue();
155 setResultRange(getResult(), getIndexRange(umin: 0, umax: max - 1ULL));
156}
157
158void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
159 SetIntRangeFn setResultRange) {
160 std::optional<uint64_t> knownVal =
161 getKnownLaunchDim(op: *this, type: LaunchDims::Block);
162 if (knownVal)
163 return setResultRange(getResult(), getIndexRange(umin: *knownVal, umax: *knownVal));
164 ;
165 uint64_t max = kMaxDim;
166 if (auto specified = getUpperBound())
167 max = specified->getZExtValue();
168 setResultRange(getResult(), getIndexRange(umin: 1, umax: max));
169}
170
171void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
172 SetIntRangeFn setResultRange) {
173 uint64_t max = kMaxDim;
174 if (auto fromContext = getKnownLaunchDim(op: *this, type: LaunchDims::Grid))
175 max = fromContext.value();
176 if (auto specified = getUpperBound())
177 max = specified->getZExtValue();
178 setResultRange(getResult(), getIndexRange(umin: 0, umax: max - 1ULL));
179}
180
181void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
182 SetIntRangeFn setResultRange) {
183 std::optional<uint64_t> knownVal = getKnownLaunchDim(op: *this, type: LaunchDims::Grid);
184 if (knownVal)
185 return setResultRange(getResult(), getIndexRange(umin: *knownVal, umax: *knownVal));
186 uint64_t max = kMaxDim;
187 if (auto specified = getUpperBound())
188 max = specified->getZExtValue();
189 setResultRange(getResult(), getIndexRange(umin: 1, umax: max));
190}
191
192void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
193 SetIntRangeFn setResultRange) {
194 uint64_t max = kMaxDim;
195 if (auto fromContext = getKnownLaunchDim(op: *this, type: LaunchDims::Block))
196 max = fromContext.value();
197 if (auto specified = getUpperBound())
198 max = specified->getZExtValue();
199 setResultRange(getResult(), getIndexRange(umin: 0, umax: max - 1ULL));
200}
201
202void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
203 SetIntRangeFn setResultRange) {
204 uint64_t max = kMaxSubgroupSize;
205 if (auto specified = getUpperBound())
206 max = specified->getZExtValue();
207 setResultRange(getResult(), getIndexRange(umin: 0, umax: max - 1ULL));
208}
209
210void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
211 SetIntRangeFn setResultRange) {
212 uint64_t max = kMaxDim;
213 if (auto specified = getUpperBound())
214 max = specified->getZExtValue();
215 setResultRange(getResult(), getIndexRange(umin: 0, umax: max - 1ULL));
216}
217
218void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
219 SetIntRangeFn setResultRange) {
220 if (auto specified = getUpperBound())
221 return setResultRange(getResult(),
222 getIndexRange(umin: 0, umax: specified->getZExtValue() - 1ULL));
223
224 uint64_t blockDimMax =
225 getKnownLaunchDim(op: *this, type: LaunchDims::Block).value_or(u: kMaxDim);
226 uint64_t gridDimMax =
227 getKnownLaunchDim(op: *this, type: LaunchDims::Grid).value_or(u: kMaxDim);
228 setResultRange(getResult(),
229 getIndexRange(umin: 0, umax: (blockDimMax * gridDimMax) - 1ULL));
230}
231
232void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
233 SetIntRangeFn setResultRange) {
234 uint64_t max = kMaxDim;
235 if (auto specified = getUpperBound())
236 max = specified->getZExtValue();
237 setResultRange(getResult(), getIndexRange(umin: 1, umax: max));
238}
239
240void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
241 SetIntRangeFn setResultRange) {
242 uint64_t max = kMaxSubgroupSize;
243 if (auto specified = getUpperBound())
244 max = specified->getZExtValue();
245 setResultRange(getResult(), getIndexRange(umin: 1, umax: max));
246}
247
248void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
249 SetIntRangeFn setResultRange) {
250 auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
251 Value idxResult) {
252 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
253 return;
254 ConstantIntRanges dimRange =
255 argRange.intersection(other: getIndexRange(umin: 1, umax: kMaxDim));
256 setResultRange(dimResult, dimRange);
257 ConstantIntRanges idxRange =
258 getIndexRange(umin: 0, umax: dimRange.umax().getZExtValue() - 1);
259 setResultRange(idxResult, idxRange);
260 };
261
262 argRanges = argRanges.drop_front(N: getAsyncDependencies().size());
263 KernelDim3 gridDims = getGridSize();
264 KernelDim3 blockIds = getBlockIds();
265 setRange(argRanges[0], gridDims.x, blockIds.x);
266 setRange(argRanges[1], gridDims.y, blockIds.y);
267 setRange(argRanges[2], gridDims.z, blockIds.z);
268 KernelDim3 blockDims = getBlockSize();
269 KernelDim3 threadIds = getThreadIds();
270 setRange(argRanges[3], blockDims.x, threadIds.x);
271 setRange(argRanges[4], blockDims.y, threadIds.y);
272 setRange(argRanges[5], blockDims.z, threadIds.z);
273}
274

source code of mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp