//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include <limits>
#include <optional>
16
17using namespace mlir;
18using namespace mlir::gpu;
19
// Maximum grid and block dimensions of all known GPUs are less than 2^32.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Maximum cluster size. NOTE(review): presumably the per-dimension bound on
// thread blocks in a cluster -- confirm against target hardware limits.
static constexpr uint64_t kMaxClusterDim = 8;
// Maximum subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
26
27static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
28 unsigned width = IndexType::kInternalStorageBitWidth;
29 return ConstantIntRanges::fromUnsigned(umin: APInt(width, umin),
30 umax: APInt(width, umax));
31}
32
namespace {
// Selects which launch bound a query refers to: the block size (threads per
// block) or the grid size (blocks per grid).
enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
} // end namespace
36
/// If the operation `op` is in a context that is annotated with maximum
/// launch dimensions (a launch op with constant block or grid sizes, or a
/// launch_func op with the appropriate dimensions), getKnownLaunchDim()
/// below returns the bound on the maximum size of the dimension that the op
/// is querying. IDs will be one less than this bound.

43static Value valueByDim(KernelDim3 dims, Dimension dim) {
44 switch (dim) {
45 case Dimension::x:
46 return dims.x;
47 case Dimension::y:
48 return dims.y;
49 case Dimension::z:
50 return dims.z;
51 }
52 llvm_unreachable("All dimension enum cases handled above");
53}
54
55static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }
56
/// Return the known launch bound (block or grid size, per `type`) for the
/// dimension that `op` queries, or std::nullopt if no bound can be found.
///
/// Two sources are consulted in order:
///   1. An enclosing gpu.launch whose corresponding size operand is a
///      constant.
///   2. An enclosing gpu.func carrying known block/grid size annotations.
template <typename Op>
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
  Dimension dim = op.getDimension();
  // Case 1: constant size operand on an enclosing gpu.launch.
  if (auto launch = op->template getParentOfType<LaunchOp>()) {
    KernelDim3 bounds;
    switch (type) {
    case LaunchDims::Block:
      bounds = launch.getBlockSizeOperandValues();
      break;
    case LaunchDims::Grid:
      bounds = launch.getGridSizeOperandValues();
      break;
    }
    Value maybeBound = valueByDim(bounds, dim);
    APInt value;
    if (matchPattern(maybeBound, m_ConstantInt(&value)))
      return value.getZExtValue();
  }

  // Case 2: known-size annotations on an enclosing gpu.func. The 32-bit
  // annotation values are zero-extended to uint64_t when present.
  if (auto func = op->template getParentOfType<GPUFuncOp>()) {
    switch (type) {
    case LaunchDims::Block:
      return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
    case LaunchDims::Grid:
      return llvm::transformOptional(func.getKnownGridSize(dim), zext);
    }
  }
  return std::nullopt;
}
86
87void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
88 SetIntRangeFn setResultRange) {
89 setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
90}
91
92void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
93 SetIntRangeFn setResultRange) {
94 uint64_t max = kMaxClusterDim;
95 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
96}
97
98void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
99 SetIntRangeFn setResultRange) {
100 std::optional<uint64_t> knownVal =
101 getKnownLaunchDim(*this, LaunchDims::Block);
102 if (knownVal)
103 setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
104 else
105 setResultRange(getResult(), getIndexRange(1, kMaxDim));
106}
107
108void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
109 SetIntRangeFn setResultRange) {
110 uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
111 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
112}
113
114void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
115 SetIntRangeFn setResultRange) {
116 std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
117 if (knownVal)
118 setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
119 else
120 setResultRange(getResult(), getIndexRange(1, kMaxDim));
121}
122
123void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
124 SetIntRangeFn setResultRange) {
125 uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
126 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
127}
128
129void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
130 SetIntRangeFn setResultRange) {
131 setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
132}
133
134void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
135 SetIntRangeFn setResultRange) {
136 setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
137}
138
139void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
140 SetIntRangeFn setResultRange) {
141 uint64_t blockDimMax =
142 getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
143 uint64_t gridDimMax =
144 getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
145 setResultRange(getResult(),
146 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
147}
148
149void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
150 SetIntRangeFn setResultRange) {
151 setResultRange(getResult(), getIndexRange(1, kMaxDim));
152}
153
154void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
155 SetIntRangeFn setResultRange) {
156 setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
157}
158
/// Propagate the ranges of the gpu.launch size operands into the region's
/// dimension and ID block arguments: each dimension gets the operand's range
/// clamped to [1, kMaxDim], and each ID gets [0, dimMax - 1].
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
                      Value idxResult) {
    // Only propagate ranges computed at the index storage bitwidth; the
    // clamping arithmetic below assumes that width.
    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
      return;
    // Dimension sizes are at least 1 and at most kMaxDim.
    ConstantIntRanges dimRange =
        argRange.intersection(getIndexRange(1, kMaxDim));
    setResultRange(dimResult, dimRange);
    // IDs are zero-based and strictly below the dimension's maximum.
    ConstantIntRanges idxRange =
        getIndexRange(0, dimRange.umax().getZExtValue() - 1);
    setResultRange(idxResult, idxRange);
  };

  // Skip the leading async-dependency operands; the six size operands follow
  // in grid-x,y,z then block-x,y,z order (see the indexing below).
  argRanges = argRanges.drop_front(getAsyncDependencies().size());
  KernelDim3 gridDims = getGridSize();
  KernelDim3 blockIds = getBlockIds();
  setRange(argRanges[0], gridDims.x, blockIds.x);
  setRange(argRanges[1], gridDims.y, blockIds.y);
  setRange(argRanges[2], gridDims.z, blockIds.z);
  KernelDim3 blockDims = getBlockSize();
  KernelDim3 threadIds = getThreadIds();
  setRange(argRanges[3], blockDims.x, threadIds.x);
  setRange(argRanges[4], blockDims.y, threadIds.y);
  setRange(argRanges[5], blockDims.z, threadIds.z);
}
185

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp