1 | //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include <limits>
15 | |
16 | namespace mlir { |
17 | |
18 | bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(ofr: v, value: 0); } |
19 | |
20 | bool isOneInteger(OpFoldResult v) { return isConstantIntValue(ofr: v, value: 1); } |
21 | |
22 | std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
23 | SmallVector<OpFoldResult>> |
24 | getOffsetsSizesAndStrides(ArrayRef<Range> ranges) { |
25 | SmallVector<OpFoldResult> offsets, sizes, strides; |
26 | offsets.reserve(N: ranges.size()); |
27 | sizes.reserve(N: ranges.size()); |
28 | strides.reserve(N: ranges.size()); |
29 | for (const auto &[offset, size, stride] : ranges) { |
30 | offsets.push_back(Elt: offset); |
31 | sizes.push_back(Elt: size); |
32 | strides.push_back(Elt: stride); |
33 | } |
34 | return std::make_tuple(args&: offsets, args&: sizes, args&: strides); |
35 | } |
36 | |
37 | /// Helper function to dispatch an OpFoldResult into `staticVec` if: |
38 | /// a) it is an IntegerAttr |
39 | /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. |
40 | /// In such dynamic cases, a copy of the `sentinel` value is also pushed to |
41 | /// `staticVec`. This is useful to extract mixed static and dynamic entries that |
42 | /// come from an AttrSizedOperandSegments trait. |
43 | void dispatchIndexOpFoldResult(OpFoldResult ofr, |
44 | SmallVectorImpl<Value> &dynamicVec, |
45 | SmallVectorImpl<int64_t> &staticVec) { |
46 | auto v = llvm::dyn_cast_if_present<Value>(Val&: ofr); |
47 | if (!v) { |
48 | APInt apInt = cast<IntegerAttr>(cast<Attribute>(Val&: ofr)).getValue(); |
49 | staticVec.push_back(Elt: apInt.getSExtValue()); |
50 | return; |
51 | } |
52 | dynamicVec.push_back(Elt: v); |
53 | staticVec.push_back(ShapedType::kDynamic); |
54 | } |
55 | |
56 | std::pair<int64_t, OpFoldResult> |
57 | getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) { |
58 | int64_t tileSizeForShape = |
59 | getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic); |
60 | |
61 | OpFoldResult tileSizeOfrSimplified = |
62 | (tileSizeForShape != ShapedType::kDynamic) |
63 | ? b.getIndexAttr(tileSizeForShape) |
64 | : tileSizeOfr; |
65 | |
66 | return std::pair<int64_t, OpFoldResult>(tileSizeForShape, |
67 | tileSizeOfrSimplified); |
68 | } |
69 | |
70 | void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, |
71 | SmallVectorImpl<Value> &dynamicVec, |
72 | SmallVectorImpl<int64_t> &staticVec) { |
73 | for (OpFoldResult ofr : ofrs) |
74 | dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec); |
75 | } |
76 | |
77 | /// Given a value, try to extract a constant Attribute. If this fails, return |
78 | /// the original value. |
79 | OpFoldResult getAsOpFoldResult(Value val) { |
80 | if (!val) |
81 | return OpFoldResult(); |
82 | Attribute attr; |
83 | if (matchPattern(value: val, pattern: m_Constant(bind_value: &attr))) |
84 | return attr; |
85 | return val; |
86 | } |
87 | |
88 | /// Given an array of values, try to extract a constant Attribute from each |
89 | /// value. If this fails, return the original value. |
90 | SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) { |
91 | return llvm::to_vector( |
92 | Range: llvm::map_range(C&: values, F: [](Value v) { return getAsOpFoldResult(val: v); })); |
93 | } |
94 | |
95 | /// Convert `arrayAttr` to a vector of OpFoldResult. |
96 | SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) { |
97 | SmallVector<OpFoldResult> res; |
98 | res.reserve(N: arrayAttr.size()); |
99 | for (Attribute a : arrayAttr) |
100 | res.push_back(a); |
101 | return res; |
102 | } |
103 | |
104 | OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) { |
105 | return IntegerAttr::get(IndexType::get(ctx), val); |
106 | } |
107 | |
108 | SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx, |
109 | ArrayRef<int64_t> values) { |
110 | return llvm::to_vector(Range: llvm::map_range( |
111 | C&: values, F: [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, val: v); })); |
112 | } |
113 | |
114 | /// If ofr is a constant integer or an IntegerAttr, return the integer. |
115 | std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { |
116 | // Case 1: Check for Constant integer. |
117 | if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr)) { |
118 | APSInt intVal; |
119 | if (matchPattern(val, m_ConstantInt(&intVal))) |
120 | return intVal.getSExtValue(); |
121 | return std::nullopt; |
122 | } |
123 | // Case 2: Check for IntegerAttr. |
124 | Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr); |
125 | if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) |
126 | return intAttr.getValue().getSExtValue(); |
127 | return std::nullopt; |
128 | } |
129 | |
130 | std::optional<SmallVector<int64_t>> |
131 | getConstantIntValues(ArrayRef<OpFoldResult> ofrs) { |
132 | bool failed = false; |
133 | SmallVector<int64_t> res = llvm::map_to_vector(C&: ofrs, F: [&](OpFoldResult ofr) { |
134 | auto cv = getConstantIntValue(ofr); |
135 | if (!cv.has_value()) |
136 | failed = true; |
137 | return cv.value_or(u: 0); |
138 | }); |
139 | if (failed) |
140 | return std::nullopt; |
141 | return res; |
142 | } |
143 | |
144 | bool isConstantIntValue(OpFoldResult ofr, int64_t value) { |
145 | auto val = getConstantIntValue(ofr); |
146 | return val && *val == value; |
147 | } |
148 | |
149 | bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) { |
150 | return llvm::all_of( |
151 | Range&: ofrs, P: [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); }); |
152 | } |
153 | |
154 | bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs, |
155 | ArrayRef<int64_t> values) { |
156 | if (ofrs.size() != values.size()) |
157 | return false; |
158 | std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs); |
159 | return constOfrs && llvm::equal(LRange&: constOfrs.value(), RRange&: values); |
160 | } |
161 | |
162 | /// Return true if ofr1 and ofr2 are the same integer constant attribute values |
163 | /// or the same SSA value. |
164 | /// Ignore integer bitwidth and type mismatch that come from the fact there is |
165 | /// no IndexAttr and that IndexType has no bitwidth. |
166 | bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { |
167 | auto cst1 = getConstantIntValue(ofr: ofr1), cst2 = getConstantIntValue(ofr: ofr2); |
168 | if (cst1 && cst2 && *cst1 == *cst2) |
169 | return true; |
170 | auto v1 = llvm::dyn_cast_if_present<Value>(Val&: ofr1), |
171 | v2 = llvm::dyn_cast_if_present<Value>(Val&: ofr2); |
172 | return v1 && v1 == v2; |
173 | } |
174 | |
175 | bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1, |
176 | ArrayRef<OpFoldResult> ofrs2) { |
177 | if (ofrs1.size() != ofrs2.size()) |
178 | return false; |
179 | for (auto [ofr1, ofr2] : llvm::zip_equal(t&: ofrs1, u&: ofrs2)) |
180 | if (!isEqualConstantIntOrValue(ofr1, ofr2)) |
181 | return false; |
182 | return true; |
183 | } |
184 | |
185 | /// Return a vector of OpFoldResults with the same size a staticValues, but all |
186 | /// elements for which ShapedType::isDynamic is true, will be replaced by |
187 | /// dynamicValues. |
188 | SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, |
189 | ValueRange dynamicValues, |
190 | MLIRContext *context) { |
191 | SmallVector<OpFoldResult> res; |
192 | res.reserve(N: staticValues.size()); |
193 | unsigned numDynamic = 0; |
194 | unsigned count = static_cast<unsigned>(staticValues.size()); |
195 | for (unsigned idx = 0; idx < count; ++idx) { |
196 | int64_t value = staticValues[idx]; |
197 | res.push_back(ShapedType::isDynamic(value) |
198 | ? OpFoldResult{dynamicValues[numDynamic++]} |
199 | : OpFoldResult{IntegerAttr::get( |
200 | IntegerType::get(context, 64), staticValues[idx])}); |
201 | } |
202 | return res; |
203 | } |
204 | SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, |
205 | ValueRange dynamicValues, Builder &b) { |
206 | return getMixedValues(staticValues, dynamicValues, context: b.getContext()); |
207 | } |
208 | |
209 | /// Decompose a vector of mixed static or dynamic values into the corresponding |
210 | /// pair of arrays. This is the inverse function of `getMixedValues`. |
211 | std::pair<SmallVector<int64_t>, SmallVector<Value>> |
212 | decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) { |
213 | SmallVector<int64_t> staticValues; |
214 | SmallVector<Value> dynamicValues; |
215 | for (const auto &it : mixedValues) { |
216 | if (auto attr = dyn_cast<Attribute>(Val: it)) { |
217 | staticValues.push_back(Elt: cast<IntegerAttr>(attr).getInt()); |
218 | } else { |
219 | staticValues.push_back(ShapedType::kDynamic); |
220 | dynamicValues.push_back(Elt: cast<Value>(Val: it)); |
221 | } |
222 | } |
223 | return {staticValues, dynamicValues}; |
224 | } |
225 | |
226 | /// Helper to sort `values` according to matching `keys`. |
227 | template <typename K, typename V> |
228 | static SmallVector<V> |
229 | getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values, |
230 | llvm::function_ref<bool(K, K)> compare) { |
231 | if (keys.empty()) |
232 | return SmallVector<V>{values}; |
233 | assert(keys.size() == values.size() && "unexpected mismatching sizes"); |
234 | auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size())); |
235 | llvm::sort(indices, |
236 | [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); }); |
237 | SmallVector<V> res; |
238 | res.reserve(values.size()); |
239 | for (int64_t i = 0, e = indices.size(); i < e; ++i) |
240 | res.push_back(values[indices[i]]); |
241 | return res; |
242 | } |
243 | |
244 | SmallVector<Value> |
245 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, |
246 | llvm::function_ref<bool(Attribute, Attribute)> compare) { |
247 | return getValuesSortedByKeyImpl(keys, values, compare); |
248 | } |
249 | |
250 | SmallVector<OpFoldResult> |
251 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values, |
252 | llvm::function_ref<bool(Attribute, Attribute)> compare) { |
253 | return getValuesSortedByKeyImpl(keys, values, compare); |
254 | } |
255 | |
256 | SmallVector<int64_t> |
257 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, |
258 | llvm::function_ref<bool(Attribute, Attribute)> compare) { |
259 | return getValuesSortedByKeyImpl(keys, values, compare); |
260 | } |
261 | |
262 | /// Return the number of iterations for a loop with a lower bound `lb`, upper |
263 | /// bound `ub` and step `step`. |
264 | std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, |
265 | OpFoldResult step) { |
266 | if (lb == ub) |
267 | return 0; |
268 | |
269 | std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb); |
270 | if (!lbConstant) |
271 | return std::nullopt; |
272 | std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub); |
273 | if (!ubConstant) |
274 | return std::nullopt; |
275 | std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step); |
276 | if (!stepConstant) |
277 | return std::nullopt; |
278 | |
279 | return llvm::divideCeilSigned(Numerator: *ubConstant - *lbConstant, Denominator: *stepConstant); |
280 | } |
281 | |
282 | bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) { |
283 | return llvm::none_of(Range&: sizesOrOffsets, P: [](int64_t value) { |
284 | return !ShapedType::isDynamic(value) && value < 0; |
285 | }); |
286 | } |
287 | |
288 | bool hasValidStrides(SmallVector<int64_t> strides) { |
289 | return llvm::none_of(Range&: strides, P: [](int64_t value) { |
290 | return !ShapedType::isDynamic(value) && value == 0; |
291 | }); |
292 | } |
293 | |
294 | LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, |
295 | bool onlyNonNegative, bool onlyNonZero) { |
296 | bool valuesChanged = false; |
297 | for (OpFoldResult &ofr : ofrs) { |
298 | if (isa<Attribute>(Val: ofr)) |
299 | continue; |
300 | Attribute attr; |
301 | if (matchPattern(value: cast<Value>(Val&: ofr), pattern: m_Constant(bind_value: &attr))) { |
302 | // Note: All ofrs have index type. |
303 | if (onlyNonNegative && *getConstantIntValue(ofr: attr) < 0) |
304 | continue; |
305 | if (onlyNonZero && *getConstantIntValue(ofr: attr) == 0) |
306 | continue; |
307 | ofr = attr; |
308 | valuesChanged = true; |
309 | } |
310 | } |
311 | return success(IsSuccess: valuesChanged); |
312 | } |
313 | |
314 | LogicalResult |
315 | foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) { |
316 | return foldDynamicIndexList(ofrs&: offsetsOrSizes, /*onlyNonNegative=*/true, |
317 | /*onlyNonZero=*/false); |
318 | } |
319 | |
320 | LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) { |
321 | return foldDynamicIndexList(ofrs&: strides, /*onlyNonNegative=*/false, |
322 | /*onlyNonZero=*/true); |
323 | } |
324 | |
325 | } // namespace mlir |
326 |
Definitions
- isZeroInteger
- isOneInteger
- getOffsetsSizesAndStrides
- dispatchIndexOpFoldResult
- getSimplifiedOfrAndStaticSizePair
- dispatchIndexOpFoldResults
- getAsOpFoldResult
- getAsOpFoldResult
- getAsOpFoldResult
- getAsIndexOpFoldResult
- getAsIndexOpFoldResult
- getConstantIntValue
- getConstantIntValues
- isConstantIntValue
- areAllConstantIntValue
- areConstantIntValues
- isEqualConstantIntOrValue
- isEqualConstantIntOrValueArray
- getMixedValues
- getMixedValues
- decomposeMixedValues
- getValuesSortedByKeyImpl
- getValuesSortedByKey
- getValuesSortedByKey
- getValuesSortedByKey
- constantTripCount
- hasValidSizesOffsets
- hasValidStrides
- foldDynamicIndexList
- foldDynamicOffsetSizeList
Improve your Profiling and Debugging skills
Find out more