1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/StaticValueUtils.h"
10#include "mlir/IR/Matchers.h"
11#include "mlir/Support/LLVM.h"
12#include "mlir/Support/MathExtras.h"
13#include "llvm/ADT/APSInt.h"
14
15namespace mlir {
16
17bool isZeroIndex(OpFoldResult v) {
18 if (!v)
19 return false;
20 std::optional<int64_t> constint = getConstantIntValue(ofr: v);
21 if (!constint)
22 return false;
23 return *constint == 0;
24}
25
26std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
27 SmallVector<OpFoldResult>>
28getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
29 SmallVector<OpFoldResult> offsets, sizes, strides;
30 offsets.reserve(N: ranges.size());
31 sizes.reserve(N: ranges.size());
32 strides.reserve(N: ranges.size());
33 for (const auto &[offset, size, stride] : ranges) {
34 offsets.push_back(Elt: offset);
35 sizes.push_back(Elt: size);
36 strides.push_back(Elt: stride);
37 }
38 return std::make_tuple(args&: offsets, args&: sizes, args&: strides);
39}
40
41/// Helper function to dispatch an OpFoldResult into `staticVec` if:
42/// a) it is an IntegerAttr
43/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
44/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
45/// `staticVec`. This is useful to extract mixed static and dynamic entries that
46/// come from an AttrSizedOperandSegments trait.
47void dispatchIndexOpFoldResult(OpFoldResult ofr,
48 SmallVectorImpl<Value> &dynamicVec,
49 SmallVectorImpl<int64_t> &staticVec) {
50 auto v = llvm::dyn_cast_if_present<Value>(Val&: ofr);
51 if (!v) {
52 APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
53 staticVec.push_back(Elt: apInt.getSExtValue());
54 return;
55 }
56 dynamicVec.push_back(Elt: v);
57 staticVec.push_back(ShapedType::kDynamic);
58}
59
60void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
61 SmallVectorImpl<Value> &dynamicVec,
62 SmallVectorImpl<int64_t> &staticVec) {
63 for (OpFoldResult ofr : ofrs)
64 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
65}
66
67/// Given a value, try to extract a constant Attribute. If this fails, return
68/// the original value.
69OpFoldResult getAsOpFoldResult(Value val) {
70 if (!val)
71 return OpFoldResult();
72 Attribute attr;
73 if (matchPattern(value: val, pattern: m_Constant(bind_value: &attr)))
74 return attr;
75 return val;
76}
77
78/// Given an array of values, try to extract a constant Attribute from each
79/// value. If this fails, return the original value.
80SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
81 return llvm::to_vector(
82 Range: llvm::map_range(C&: values, F: [](Value v) { return getAsOpFoldResult(val: v); }));
83}
84
85/// Convert `arrayAttr` to a vector of OpFoldResult.
86SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
87 SmallVector<OpFoldResult> res;
88 res.reserve(N: arrayAttr.size());
89 for (Attribute a : arrayAttr)
90 res.push_back(a);
91 return res;
92}
93
94OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
95 return IntegerAttr::get(IndexType::get(ctx), val);
96}
97
98SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
99 ArrayRef<int64_t> values) {
100 return llvm::to_vector(Range: llvm::map_range(
101 C&: values, F: [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, val: v); }));
102}
103
104/// If ofr is a constant integer or an IntegerAttr, return the integer.
105std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
106 // Case 1: Check for Constant integer.
107 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr)) {
108 APSInt intVal;
109 if (matchPattern(val, m_ConstantInt(&intVal)))
110 return intVal.getSExtValue();
111 return std::nullopt;
112 }
113 // Case 2: Check for IntegerAttr.
114 Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr);
115 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
116 return intAttr.getValue().getSExtValue();
117 return std::nullopt;
118}
119
120std::optional<SmallVector<int64_t>>
121getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
122 bool failed = false;
123 SmallVector<int64_t> res = llvm::map_to_vector(C&: ofrs, F: [&](OpFoldResult ofr) {
124 auto cv = getConstantIntValue(ofr);
125 if (!cv.has_value())
126 failed = true;
127 return cv.has_value() ? cv.value() : 0;
128 });
129 if (failed)
130 return std::nullopt;
131 return res;
132}
133
134/// Return true if `ofr` is constant integer equal to `value`.
135bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
136 auto val = getConstantIntValue(ofr);
137 return val && *val == value;
138}
139
140/// Return true if ofr1 and ofr2 are the same integer constant attribute values
141/// or the same SSA value.
142/// Ignore integer bitwidth and type mismatch that come from the fact there is
143/// no IndexAttr and that IndexType has no bitwidth.
144bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
145 auto cst1 = getConstantIntValue(ofr: ofr1), cst2 = getConstantIntValue(ofr: ofr2);
146 if (cst1 && cst2 && *cst1 == *cst2)
147 return true;
148 auto v1 = llvm::dyn_cast_if_present<Value>(Val&: ofr1),
149 v2 = llvm::dyn_cast_if_present<Value>(Val&: ofr2);
150 return v1 && v1 == v2;
151}
152
153bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
154 ArrayRef<OpFoldResult> ofrs2) {
155 if (ofrs1.size() != ofrs2.size())
156 return false;
157 for (auto [ofr1, ofr2] : llvm::zip_equal(t&: ofrs1, u&: ofrs2))
158 if (!isEqualConstantIntOrValue(ofr1, ofr2))
159 return false;
160 return true;
161}
162
163/// Return a vector of OpFoldResults with the same size a staticValues, but all
164/// elements for which ShapedType::isDynamic is true, will be replaced by
165/// dynamicValues.
166SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
167 ValueRange dynamicValues, Builder &b) {
168 SmallVector<OpFoldResult> res;
169 res.reserve(N: staticValues.size());
170 unsigned numDynamic = 0;
171 unsigned count = static_cast<unsigned>(staticValues.size());
172 for (unsigned idx = 0; idx < count; ++idx) {
173 int64_t value = staticValues[idx];
174 res.push_back(ShapedType::isDynamic(value)
175 ? OpFoldResult{dynamicValues[numDynamic++]}
176 : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
177 }
178 return res;
179}
180
181/// Decompose a vector of mixed static or dynamic values into the corresponding
182/// pair of arrays. This is the inverse function of `getMixedValues`.
183std::pair<ArrayAttr, SmallVector<Value>>
184decomposeMixedValues(Builder &b,
185 const SmallVectorImpl<OpFoldResult> &mixedValues) {
186 SmallVector<int64_t> staticValues;
187 SmallVector<Value> dynamicValues;
188 for (const auto &it : mixedValues) {
189 if (it.is<Attribute>()) {
190 staticValues.push_back(Elt: cast<IntegerAttr>(it.get<Attribute>()).getInt());
191 } else {
192 staticValues.push_back(ShapedType::kDynamic);
193 dynamicValues.push_back(Elt: it.get<Value>());
194 }
195 }
196 return {b.getI64ArrayAttr(staticValues), dynamicValues};
197}
198
199/// Helper to sort `values` according to matching `keys`.
200template <typename K, typename V>
201static SmallVector<V>
202getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
203 llvm::function_ref<bool(K, K)> compare) {
204 if (keys.empty())
205 return SmallVector<V>{values};
206 assert(keys.size() == values.size() && "unexpected mismatching sizes");
207 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
208 std::sort(indices.begin(), indices.end(),
209 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
210 SmallVector<V> res;
211 res.reserve(values.size());
212 for (int64_t i = 0, e = indices.size(); i < e; ++i)
213 res.push_back(values[indices[i]]);
214 return res;
215}
216
217SmallVector<Value>
218getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
219 llvm::function_ref<bool(Attribute, Attribute)> compare) {
220 return getValuesSortedByKeyImpl(keys, values, compare);
221}
222
223SmallVector<OpFoldResult>
224getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
225 llvm::function_ref<bool(Attribute, Attribute)> compare) {
226 return getValuesSortedByKeyImpl(keys, values, compare);
227}
228
229SmallVector<int64_t>
230getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
231 llvm::function_ref<bool(Attribute, Attribute)> compare) {
232 return getValuesSortedByKeyImpl(keys, values, compare);
233}
234
235/// Return the number of iterations for a loop with a lower bound `lb`, upper
236/// bound `ub` and step `step`.
237std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
238 OpFoldResult step) {
239 if (lb == ub)
240 return 0;
241
242 std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb);
243 if (!lbConstant)
244 return std::nullopt;
245 std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub);
246 if (!ubConstant)
247 return std::nullopt;
248 std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step);
249 if (!stepConstant)
250 return std::nullopt;
251
252 return mlir::ceilDiv(lhs: *ubConstant - *lbConstant, rhs: *stepConstant);
253}
254
255bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
256 return llvm::none_of(Range&: sizesOrOffsets, P: [](int64_t value) {
257 return !ShapedType::isDynamic(value) && value < 0;
258 });
259}
260
261bool hasValidStrides(SmallVector<int64_t> strides) {
262 return llvm::none_of(Range&: strides, P: [](int64_t value) {
263 return !ShapedType::isDynamic(value) && value == 0;
264 });
265}
266
267LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
268 bool onlyNonNegative, bool onlyNonZero) {
269 bool valuesChanged = false;
270 for (OpFoldResult &ofr : ofrs) {
271 if (ofr.is<Attribute>())
272 continue;
273 Attribute attr;
274 if (matchPattern(value: ofr.get<Value>(), pattern: m_Constant(bind_value: &attr))) {
275 // Note: All ofrs have index type.
276 if (onlyNonNegative && *getConstantIntValue(ofr: attr) < 0)
277 continue;
278 if (onlyNonZero && *getConstantIntValue(ofr: attr) == 0)
279 continue;
280 ofr = attr;
281 valuesChanged = true;
282 }
283 }
284 return success(isSuccess: valuesChanged);
285}
286
287LogicalResult
288foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
289 return foldDynamicIndexList(ofrs&: offsetsOrSizes, /*onlyNonNegative=*/true,
290 /*onlyNonZero=*/false);
291}
292
293LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
294 return foldDynamicIndexList(ofrs&: strides, /*onlyNonNegative=*/false,
295 /*onlyNonZero=*/true);
296}
297
298} // namespace mlir
299

source code of mlir/lib/Dialect/Utils/StaticValueUtils.cpp