//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/Support/MathExtras.h"

#include <algorithm>
#include <limits>
#include <utility>
14 | |
15 | namespace mlir { |
16 | |
17 | bool isZeroIndex(OpFoldResult v) { |
18 | if (!v) |
19 | return false; |
20 | std::optional<int64_t> constint = getConstantIntValue(ofr: v); |
21 | if (!constint) |
22 | return false; |
23 | return *constint == 0; |
24 | } |
25 | |
26 | std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
27 | SmallVector<OpFoldResult>> |
28 | getOffsetsSizesAndStrides(ArrayRef<Range> ranges) { |
29 | SmallVector<OpFoldResult> offsets, sizes, strides; |
30 | offsets.reserve(N: ranges.size()); |
31 | sizes.reserve(N: ranges.size()); |
32 | strides.reserve(N: ranges.size()); |
33 | for (const auto &[offset, size, stride] : ranges) { |
34 | offsets.push_back(Elt: offset); |
35 | sizes.push_back(Elt: size); |
36 | strides.push_back(Elt: stride); |
37 | } |
38 | return std::make_tuple(args&: offsets, args&: sizes, args&: strides); |
39 | } |
40 | |
41 | /// Helper function to dispatch an OpFoldResult into `staticVec` if: |
42 | /// a) it is an IntegerAttr |
43 | /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. |
44 | /// In such dynamic cases, a copy of the `sentinel` value is also pushed to |
45 | /// `staticVec`. This is useful to extract mixed static and dynamic entries that |
46 | /// come from an AttrSizedOperandSegments trait. |
47 | void dispatchIndexOpFoldResult(OpFoldResult ofr, |
48 | SmallVectorImpl<Value> &dynamicVec, |
49 | SmallVectorImpl<int64_t> &staticVec) { |
50 | auto v = llvm::dyn_cast_if_present<Value>(Val&: ofr); |
51 | if (!v) { |
52 | APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue(); |
53 | staticVec.push_back(Elt: apInt.getSExtValue()); |
54 | return; |
55 | } |
56 | dynamicVec.push_back(Elt: v); |
57 | staticVec.push_back(ShapedType::kDynamic); |
58 | } |
59 | |
60 | void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, |
61 | SmallVectorImpl<Value> &dynamicVec, |
62 | SmallVectorImpl<int64_t> &staticVec) { |
63 | for (OpFoldResult ofr : ofrs) |
64 | dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec); |
65 | } |
66 | |
67 | /// Given a value, try to extract a constant Attribute. If this fails, return |
68 | /// the original value. |
69 | OpFoldResult getAsOpFoldResult(Value val) { |
70 | if (!val) |
71 | return OpFoldResult(); |
72 | Attribute attr; |
73 | if (matchPattern(value: val, pattern: m_Constant(bind_value: &attr))) |
74 | return attr; |
75 | return val; |
76 | } |
77 | |
78 | /// Given an array of values, try to extract a constant Attribute from each |
79 | /// value. If this fails, return the original value. |
80 | SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) { |
81 | return llvm::to_vector( |
82 | Range: llvm::map_range(C&: values, F: [](Value v) { return getAsOpFoldResult(val: v); })); |
83 | } |
84 | |
85 | /// Convert `arrayAttr` to a vector of OpFoldResult. |
86 | SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) { |
87 | SmallVector<OpFoldResult> res; |
88 | res.reserve(N: arrayAttr.size()); |
89 | for (Attribute a : arrayAttr) |
90 | res.push_back(a); |
91 | return res; |
92 | } |
93 | |
94 | OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) { |
95 | return IntegerAttr::get(IndexType::get(ctx), val); |
96 | } |
97 | |
98 | SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx, |
99 | ArrayRef<int64_t> values) { |
100 | return llvm::to_vector(Range: llvm::map_range( |
101 | C&: values, F: [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, val: v); })); |
102 | } |
103 | |
104 | /// If ofr is a constant integer or an IntegerAttr, return the integer. |
105 | std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { |
106 | // Case 1: Check for Constant integer. |
107 | if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr)) { |
108 | APSInt intVal; |
109 | if (matchPattern(val, m_ConstantInt(&intVal))) |
110 | return intVal.getSExtValue(); |
111 | return std::nullopt; |
112 | } |
113 | // Case 2: Check for IntegerAttr. |
114 | Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr); |
115 | if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) |
116 | return intAttr.getValue().getSExtValue(); |
117 | return std::nullopt; |
118 | } |
119 | |
120 | std::optional<SmallVector<int64_t>> |
121 | getConstantIntValues(ArrayRef<OpFoldResult> ofrs) { |
122 | bool failed = false; |
123 | SmallVector<int64_t> res = llvm::map_to_vector(C&: ofrs, F: [&](OpFoldResult ofr) { |
124 | auto cv = getConstantIntValue(ofr); |
125 | if (!cv.has_value()) |
126 | failed = true; |
127 | return cv.has_value() ? cv.value() : 0; |
128 | }); |
129 | if (failed) |
130 | return std::nullopt; |
131 | return res; |
132 | } |
133 | |
134 | /// Return true if `ofr` is constant integer equal to `value`. |
135 | bool isConstantIntValue(OpFoldResult ofr, int64_t value) { |
136 | auto val = getConstantIntValue(ofr); |
137 | return val && *val == value; |
138 | } |
139 | |
140 | /// Return true if ofr1 and ofr2 are the same integer constant attribute values |
141 | /// or the same SSA value. |
142 | /// Ignore integer bitwidth and type mismatch that come from the fact there is |
143 | /// no IndexAttr and that IndexType has no bitwidth. |
144 | bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { |
145 | auto cst1 = getConstantIntValue(ofr: ofr1), cst2 = getConstantIntValue(ofr: ofr2); |
146 | if (cst1 && cst2 && *cst1 == *cst2) |
147 | return true; |
148 | auto v1 = llvm::dyn_cast_if_present<Value>(Val&: ofr1), |
149 | v2 = llvm::dyn_cast_if_present<Value>(Val&: ofr2); |
150 | return v1 && v1 == v2; |
151 | } |
152 | |
153 | bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1, |
154 | ArrayRef<OpFoldResult> ofrs2) { |
155 | if (ofrs1.size() != ofrs2.size()) |
156 | return false; |
157 | for (auto [ofr1, ofr2] : llvm::zip_equal(t&: ofrs1, u&: ofrs2)) |
158 | if (!isEqualConstantIntOrValue(ofr1, ofr2)) |
159 | return false; |
160 | return true; |
161 | } |
162 | |
163 | /// Return a vector of OpFoldResults with the same size a staticValues, but all |
164 | /// elements for which ShapedType::isDynamic is true, will be replaced by |
165 | /// dynamicValues. |
166 | SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, |
167 | ValueRange dynamicValues, Builder &b) { |
168 | SmallVector<OpFoldResult> res; |
169 | res.reserve(N: staticValues.size()); |
170 | unsigned numDynamic = 0; |
171 | unsigned count = static_cast<unsigned>(staticValues.size()); |
172 | for (unsigned idx = 0; idx < count; ++idx) { |
173 | int64_t value = staticValues[idx]; |
174 | res.push_back(ShapedType::isDynamic(value) |
175 | ? OpFoldResult{dynamicValues[numDynamic++]} |
176 | : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])}); |
177 | } |
178 | return res; |
179 | } |
180 | |
181 | /// Decompose a vector of mixed static or dynamic values into the corresponding |
182 | /// pair of arrays. This is the inverse function of `getMixedValues`. |
183 | std::pair<ArrayAttr, SmallVector<Value>> |
184 | decomposeMixedValues(Builder &b, |
185 | const SmallVectorImpl<OpFoldResult> &mixedValues) { |
186 | SmallVector<int64_t> staticValues; |
187 | SmallVector<Value> dynamicValues; |
188 | for (const auto &it : mixedValues) { |
189 | if (it.is<Attribute>()) { |
190 | staticValues.push_back(Elt: cast<IntegerAttr>(it.get<Attribute>()).getInt()); |
191 | } else { |
192 | staticValues.push_back(ShapedType::kDynamic); |
193 | dynamicValues.push_back(Elt: it.get<Value>()); |
194 | } |
195 | } |
196 | return {b.getI64ArrayAttr(staticValues), dynamicValues}; |
197 | } |
198 | |
199 | /// Helper to sort `values` according to matching `keys`. |
200 | template <typename K, typename V> |
201 | static SmallVector<V> |
202 | getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values, |
203 | llvm::function_ref<bool(K, K)> compare) { |
204 | if (keys.empty()) |
205 | return SmallVector<V>{values}; |
206 | assert(keys.size() == values.size() && "unexpected mismatching sizes" ); |
207 | auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size())); |
208 | std::sort(indices.begin(), indices.end(), |
209 | [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); }); |
210 | SmallVector<V> res; |
211 | res.reserve(values.size()); |
212 | for (int64_t i = 0, e = indices.size(); i < e; ++i) |
213 | res.push_back(values[indices[i]]); |
214 | return res; |
215 | } |
216 | |
217 | SmallVector<Value> |
218 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, |
219 | llvm::function_ref<bool(Attribute, Attribute)> compare) { |
220 | return getValuesSortedByKeyImpl(keys, values, compare); |
221 | } |
222 | |
223 | SmallVector<OpFoldResult> |
224 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values, |
225 | llvm::function_ref<bool(Attribute, Attribute)> compare) { |
226 | return getValuesSortedByKeyImpl(keys, values, compare); |
227 | } |
228 | |
229 | SmallVector<int64_t> |
230 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, |
231 | llvm::function_ref<bool(Attribute, Attribute)> compare) { |
232 | return getValuesSortedByKeyImpl(keys, values, compare); |
233 | } |
234 | |
235 | /// Return the number of iterations for a loop with a lower bound `lb`, upper |
236 | /// bound `ub` and step `step`. |
237 | std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, |
238 | OpFoldResult step) { |
239 | if (lb == ub) |
240 | return 0; |
241 | |
242 | std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb); |
243 | if (!lbConstant) |
244 | return std::nullopt; |
245 | std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub); |
246 | if (!ubConstant) |
247 | return std::nullopt; |
248 | std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step); |
249 | if (!stepConstant) |
250 | return std::nullopt; |
251 | |
252 | return mlir::ceilDiv(lhs: *ubConstant - *lbConstant, rhs: *stepConstant); |
253 | } |
254 | |
255 | bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) { |
256 | return llvm::none_of(Range&: sizesOrOffsets, P: [](int64_t value) { |
257 | return !ShapedType::isDynamic(value) && value < 0; |
258 | }); |
259 | } |
260 | |
261 | bool hasValidStrides(SmallVector<int64_t> strides) { |
262 | return llvm::none_of(Range&: strides, P: [](int64_t value) { |
263 | return !ShapedType::isDynamic(value) && value == 0; |
264 | }); |
265 | } |
266 | |
267 | LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, |
268 | bool onlyNonNegative, bool onlyNonZero) { |
269 | bool valuesChanged = false; |
270 | for (OpFoldResult &ofr : ofrs) { |
271 | if (ofr.is<Attribute>()) |
272 | continue; |
273 | Attribute attr; |
274 | if (matchPattern(value: ofr.get<Value>(), pattern: m_Constant(bind_value: &attr))) { |
275 | // Note: All ofrs have index type. |
276 | if (onlyNonNegative && *getConstantIntValue(ofr: attr) < 0) |
277 | continue; |
278 | if (onlyNonZero && *getConstantIntValue(ofr: attr) == 0) |
279 | continue; |
280 | ofr = attr; |
281 | valuesChanged = true; |
282 | } |
283 | } |
284 | return success(isSuccess: valuesChanged); |
285 | } |
286 | |
287 | LogicalResult |
288 | foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) { |
289 | return foldDynamicIndexList(ofrs&: offsetsOrSizes, /*onlyNonNegative=*/true, |
290 | /*onlyNonZero=*/false); |
291 | } |
292 | |
293 | LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) { |
294 | return foldDynamicIndexList(ofrs&: strides, /*onlyNonNegative=*/false, |
295 | /*onlyNonZero=*/true); |
296 | } |
297 | |
298 | } // namespace mlir |
299 | |