1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/StaticValueUtils.h"
10#include "mlir/IR/Matchers.h"
11#include "mlir/Support/LLVM.h"
12#include "llvm/ADT/APSInt.h"
13#include "llvm/ADT/STLExtras.h"
14#include "llvm/Support/MathExtras.h"
15
16namespace mlir {
17
18bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(ofr: v, value: 0); }
19
20bool isOneInteger(OpFoldResult v) { return isConstantIntValue(ofr: v, value: 1); }
21
22std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
23 SmallVector<OpFoldResult>>
24getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
25 SmallVector<OpFoldResult> offsets, sizes, strides;
26 offsets.reserve(N: ranges.size());
27 sizes.reserve(N: ranges.size());
28 strides.reserve(N: ranges.size());
29 for (const auto &[offset, size, stride] : ranges) {
30 offsets.push_back(Elt: offset);
31 sizes.push_back(Elt: size);
32 strides.push_back(Elt: stride);
33 }
34 return std::make_tuple(args&: offsets, args&: sizes, args&: strides);
35}
36
37/// Helper function to dispatch an OpFoldResult into `staticVec` if:
38/// a) it is an IntegerAttr
39/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
40/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
41/// `staticVec`. This is useful to extract mixed static and dynamic entries that
42/// come from an AttrSizedOperandSegments trait.
43void dispatchIndexOpFoldResult(OpFoldResult ofr,
44 SmallVectorImpl<Value> &dynamicVec,
45 SmallVectorImpl<int64_t> &staticVec) {
46 auto v = llvm::dyn_cast_if_present<Value>(Val&: ofr);
47 if (!v) {
48 APInt apInt = cast<IntegerAttr>(Val: cast<Attribute>(Val&: ofr)).getValue();
49 staticVec.push_back(Elt: apInt.getSExtValue());
50 return;
51 }
52 dynamicVec.push_back(Elt: v);
53 staticVec.push_back(Elt: ShapedType::kDynamic);
54}
55
56std::pair<int64_t, OpFoldResult>
57getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
58 int64_t tileSizeForShape =
59 getConstantIntValue(ofr: tileSizeOfr).value_or(u: ShapedType::kDynamic);
60
61 OpFoldResult tileSizeOfrSimplified =
62 (tileSizeForShape != ShapedType::kDynamic)
63 ? b.getIndexAttr(value: tileSizeForShape)
64 : tileSizeOfr;
65
66 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
67 tileSizeOfrSimplified);
68}
69
70void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
71 SmallVectorImpl<Value> &dynamicVec,
72 SmallVectorImpl<int64_t> &staticVec) {
73 for (OpFoldResult ofr : ofrs)
74 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
75}
76
77/// Given a value, try to extract a constant Attribute. If this fails, return
78/// the original value.
79OpFoldResult getAsOpFoldResult(Value val) {
80 if (!val)
81 return OpFoldResult();
82 Attribute attr;
83 if (matchPattern(value: val, pattern: m_Constant(bind_value: &attr)))
84 return attr;
85 return val;
86}
87
88/// Given an array of values, try to extract a constant Attribute from each
89/// value. If this fails, return the original value.
90SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
91 return llvm::to_vector(
92 Range: llvm::map_range(C&: values, F: [](Value v) { return getAsOpFoldResult(val: v); }));
93}
94
95/// Convert `arrayAttr` to a vector of OpFoldResult.
96SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
97 SmallVector<OpFoldResult> res;
98 res.reserve(N: arrayAttr.size());
99 for (Attribute a : arrayAttr)
100 res.push_back(Elt: a);
101 return res;
102}
103
104OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
105 return IntegerAttr::get(type: IndexType::get(context: ctx), value: val);
106}
107
108SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
109 ArrayRef<int64_t> values) {
110 return llvm::to_vector(Range: llvm::map_range(
111 C&: values, F: [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, val: v); }));
112}
113
114/// If ofr is a constant integer or an IntegerAttr, return the integer.
115std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
116 // Case 1: Check for Constant integer.
117 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr)) {
118 APSInt intVal;
119 if (matchPattern(value: val, pattern: m_ConstantInt(bind_value: &intVal)))
120 return intVal.getSExtValue();
121 return std::nullopt;
122 }
123 // Case 2: Check for IntegerAttr.
124 Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr);
125 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(Val&: attr))
126 return intAttr.getValue().getSExtValue();
127 return std::nullopt;
128}
129
130std::optional<SmallVector<int64_t>>
131getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
132 bool failed = false;
133 SmallVector<int64_t> res = llvm::map_to_vector(C&: ofrs, F: [&](OpFoldResult ofr) {
134 auto cv = getConstantIntValue(ofr);
135 if (!cv.has_value())
136 failed = true;
137 return cv.value_or(u: 0);
138 });
139 if (failed)
140 return std::nullopt;
141 return res;
142}
143
144bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
145 return getConstantIntValue(ofr) == value;
146}
147
148bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
149 return llvm::all_of(
150 Range&: ofrs, P: [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
151}
152
153bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
154 ArrayRef<int64_t> values) {
155 if (ofrs.size() != values.size())
156 return false;
157 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
158 return constOfrs && llvm::equal(LRange&: constOfrs.value(), RRange&: values);
159}
160
161/// Return true if ofr1 and ofr2 are the same integer constant attribute values
162/// or the same SSA value.
163/// Ignore integer bitwidth and type mismatch that come from the fact there is
164/// no IndexAttr and that IndexType has no bitwidth.
165bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
166 auto cst1 = getConstantIntValue(ofr: ofr1), cst2 = getConstantIntValue(ofr: ofr2);
167 if (cst1 && cst2 && *cst1 == *cst2)
168 return true;
169 auto v1 = llvm::dyn_cast_if_present<Value>(Val&: ofr1),
170 v2 = llvm::dyn_cast_if_present<Value>(Val&: ofr2);
171 return v1 && v1 == v2;
172}
173
174bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
175 ArrayRef<OpFoldResult> ofrs2) {
176 if (ofrs1.size() != ofrs2.size())
177 return false;
178 for (auto [ofr1, ofr2] : llvm::zip_equal(t&: ofrs1, u&: ofrs2))
179 if (!isEqualConstantIntOrValue(ofr1, ofr2))
180 return false;
181 return true;
182}
183
184/// Return a vector of OpFoldResults with the same size a staticValues, but all
185/// elements for which ShapedType::isDynamic is true, will be replaced by
186/// dynamicValues.
187SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
188 ValueRange dynamicValues,
189 MLIRContext *context) {
190 SmallVector<OpFoldResult> res;
191 res.reserve(N: staticValues.size());
192 unsigned numDynamic = 0;
193 unsigned count = static_cast<unsigned>(staticValues.size());
194 for (unsigned idx = 0; idx < count; ++idx) {
195 int64_t value = staticValues[idx];
196 res.push_back(Elt: ShapedType::isDynamic(dValue: value)
197 ? OpFoldResult{dynamicValues[numDynamic++]}
198 : OpFoldResult{IntegerAttr::get(
199 type: IntegerType::get(context, width: 64), value: staticValues[idx])});
200 }
201 return res;
202}
203SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
204 ValueRange dynamicValues, Builder &b) {
205 return getMixedValues(staticValues, dynamicValues, context: b.getContext());
206}
207
208/// Decompose a vector of mixed static or dynamic values into the corresponding
209/// pair of arrays. This is the inverse function of `getMixedValues`.
210std::pair<SmallVector<int64_t>, SmallVector<Value>>
211decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues) {
212 SmallVector<int64_t> staticValues;
213 SmallVector<Value> dynamicValues;
214 for (const auto &it : mixedValues) {
215 if (auto attr = dyn_cast<Attribute>(Val: it)) {
216 staticValues.push_back(Elt: cast<IntegerAttr>(Val&: attr).getInt());
217 } else {
218 staticValues.push_back(Elt: ShapedType::kDynamic);
219 dynamicValues.push_back(Elt: cast<Value>(Val: it));
220 }
221 }
222 return {staticValues, dynamicValues};
223}
224
225/// Helper to sort `values` according to matching `keys`.
226template <typename K, typename V>
227static SmallVector<V>
228getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
229 llvm::function_ref<bool(K, K)> compare) {
230 if (keys.empty())
231 return SmallVector<V>{values};
232 assert(keys.size() == values.size() && "unexpected mismatching sizes");
233 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
234 llvm::sort(indices,
235 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
236 SmallVector<V> res;
237 res.reserve(values.size());
238 for (int64_t i = 0, e = indices.size(); i < e; ++i)
239 res.push_back(values[indices[i]]);
240 return res;
241}
242
243SmallVector<Value>
244getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
245 llvm::function_ref<bool(Attribute, Attribute)> compare) {
246 return getValuesSortedByKeyImpl(keys, values, compare);
247}
248
249SmallVector<OpFoldResult>
250getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
251 llvm::function_ref<bool(Attribute, Attribute)> compare) {
252 return getValuesSortedByKeyImpl(keys, values, compare);
253}
254
255SmallVector<int64_t>
256getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
257 llvm::function_ref<bool(Attribute, Attribute)> compare) {
258 return getValuesSortedByKeyImpl(keys, values, compare);
259}
260
261/// Return the number of iterations for a loop with a lower bound `lb`, upper
262/// bound `ub` and step `step`.
263std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
264 OpFoldResult step) {
265 if (lb == ub)
266 return 0;
267
268 std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb);
269 if (!lbConstant)
270 return std::nullopt;
271 std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub);
272 if (!ubConstant)
273 return std::nullopt;
274 std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step);
275 if (!stepConstant || *stepConstant == 0)
276 return std::nullopt;
277
278 return llvm::divideCeilSigned(Numerator: *ubConstant - *lbConstant, Denominator: *stepConstant);
279}
280
281bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
282 return llvm::none_of(Range&: sizesOrOffsets, P: [](int64_t value) {
283 return ShapedType::isStatic(dValue: value) && value < 0;
284 });
285}
286
287bool hasValidStrides(SmallVector<int64_t> strides) {
288 return llvm::none_of(Range&: strides, P: [](int64_t value) {
289 return ShapedType::isStatic(dValue: value) && value == 0;
290 });
291}
292
293LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
294 bool onlyNonNegative, bool onlyNonZero) {
295 bool valuesChanged = false;
296 for (OpFoldResult &ofr : ofrs) {
297 if (isa<Attribute>(Val: ofr))
298 continue;
299 Attribute attr;
300 if (matchPattern(value: cast<Value>(Val&: ofr), pattern: m_Constant(bind_value: &attr))) {
301 // Note: All ofrs have index type.
302 if (onlyNonNegative && *getConstantIntValue(ofr: attr) < 0)
303 continue;
304 if (onlyNonZero && *getConstantIntValue(ofr: attr) == 0)
305 continue;
306 ofr = attr;
307 valuesChanged = true;
308 }
309 }
310 return success(IsSuccess: valuesChanged);
311}
312
313LogicalResult
314foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
315 return foldDynamicIndexList(ofrs&: offsetsOrSizes, /*onlyNonNegative=*/true,
316 /*onlyNonZero=*/false);
317}
318
319LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
320 return foldDynamicIndexList(ofrs&: strides, /*onlyNonNegative=*/false,
321 /*onlyNonZero=*/true);
322}
323
324} // namespace mlir
325

source code of mlir/lib/Dialect/Utils/StaticValueUtils.cpp