1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include <limits>
#include <utility>
15
16namespace mlir {
17
18bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(ofr: v, value: 0); }
19
20bool isOneInteger(OpFoldResult v) { return isConstantIntValue(ofr: v, value: 1); }
21
22std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
23 SmallVector<OpFoldResult>>
24getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
25 SmallVector<OpFoldResult> offsets, sizes, strides;
26 offsets.reserve(N: ranges.size());
27 sizes.reserve(N: ranges.size());
28 strides.reserve(N: ranges.size());
29 for (const auto &[offset, size, stride] : ranges) {
30 offsets.push_back(Elt: offset);
31 sizes.push_back(Elt: size);
32 strides.push_back(Elt: stride);
33 }
34 return std::make_tuple(args&: offsets, args&: sizes, args&: strides);
35}
36
37/// Helper function to dispatch an OpFoldResult into `staticVec` if:
38/// a) it is an IntegerAttr
39/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
40/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
41/// `staticVec`. This is useful to extract mixed static and dynamic entries that
42/// come from an AttrSizedOperandSegments trait.
43void dispatchIndexOpFoldResult(OpFoldResult ofr,
44 SmallVectorImpl<Value> &dynamicVec,
45 SmallVectorImpl<int64_t> &staticVec) {
46 auto v = llvm::dyn_cast_if_present<Value>(Val&: ofr);
47 if (!v) {
48 APInt apInt = cast<IntegerAttr>(cast<Attribute>(Val&: ofr)).getValue();
49 staticVec.push_back(Elt: apInt.getSExtValue());
50 return;
51 }
52 dynamicVec.push_back(Elt: v);
53 staticVec.push_back(ShapedType::kDynamic);
54}
55
56std::pair<int64_t, OpFoldResult>
57getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
58 int64_t tileSizeForShape =
59 getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
60
61 OpFoldResult tileSizeOfrSimplified =
62 (tileSizeForShape != ShapedType::kDynamic)
63 ? b.getIndexAttr(tileSizeForShape)
64 : tileSizeOfr;
65
66 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
67 tileSizeOfrSimplified);
68}
69
70void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
71 SmallVectorImpl<Value> &dynamicVec,
72 SmallVectorImpl<int64_t> &staticVec) {
73 for (OpFoldResult ofr : ofrs)
74 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
75}
76
77/// Given a value, try to extract a constant Attribute. If this fails, return
78/// the original value.
79OpFoldResult getAsOpFoldResult(Value val) {
80 if (!val)
81 return OpFoldResult();
82 Attribute attr;
83 if (matchPattern(value: val, pattern: m_Constant(bind_value: &attr)))
84 return attr;
85 return val;
86}
87
88/// Given an array of values, try to extract a constant Attribute from each
89/// value. If this fails, return the original value.
90SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
91 return llvm::to_vector(
92 Range: llvm::map_range(C&: values, F: [](Value v) { return getAsOpFoldResult(val: v); }));
93}
94
95/// Convert `arrayAttr` to a vector of OpFoldResult.
96SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
97 SmallVector<OpFoldResult> res;
98 res.reserve(N: arrayAttr.size());
99 for (Attribute a : arrayAttr)
100 res.push_back(a);
101 return res;
102}
103
104OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
105 return IntegerAttr::get(IndexType::get(ctx), val);
106}
107
108SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
109 ArrayRef<int64_t> values) {
110 return llvm::to_vector(Range: llvm::map_range(
111 C&: values, F: [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, val: v); }));
112}
113
114/// If ofr is a constant integer or an IntegerAttr, return the integer.
115std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
116 // Case 1: Check for Constant integer.
117 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: ofr)) {
118 APSInt intVal;
119 if (matchPattern(val, m_ConstantInt(&intVal)))
120 return intVal.getSExtValue();
121 return std::nullopt;
122 }
123 // Case 2: Check for IntegerAttr.
124 Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr);
125 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
126 return intAttr.getValue().getSExtValue();
127 return std::nullopt;
128}
129
130std::optional<SmallVector<int64_t>>
131getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
132 bool failed = false;
133 SmallVector<int64_t> res = llvm::map_to_vector(C&: ofrs, F: [&](OpFoldResult ofr) {
134 auto cv = getConstantIntValue(ofr);
135 if (!cv.has_value())
136 failed = true;
137 return cv.value_or(u: 0);
138 });
139 if (failed)
140 return std::nullopt;
141 return res;
142}
143
144bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
145 auto val = getConstantIntValue(ofr);
146 return val && *val == value;
147}
148
149bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
150 return llvm::all_of(
151 Range&: ofrs, P: [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
152}
153
154bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
155 ArrayRef<int64_t> values) {
156 if (ofrs.size() != values.size())
157 return false;
158 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
159 return constOfrs && llvm::equal(LRange&: constOfrs.value(), RRange&: values);
160}
161
162/// Return true if ofr1 and ofr2 are the same integer constant attribute values
163/// or the same SSA value.
164/// Ignore integer bitwidth and type mismatch that come from the fact there is
165/// no IndexAttr and that IndexType has no bitwidth.
166bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
167 auto cst1 = getConstantIntValue(ofr: ofr1), cst2 = getConstantIntValue(ofr: ofr2);
168 if (cst1 && cst2 && *cst1 == *cst2)
169 return true;
170 auto v1 = llvm::dyn_cast_if_present<Value>(Val&: ofr1),
171 v2 = llvm::dyn_cast_if_present<Value>(Val&: ofr2);
172 return v1 && v1 == v2;
173}
174
175bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
176 ArrayRef<OpFoldResult> ofrs2) {
177 if (ofrs1.size() != ofrs2.size())
178 return false;
179 for (auto [ofr1, ofr2] : llvm::zip_equal(t&: ofrs1, u&: ofrs2))
180 if (!isEqualConstantIntOrValue(ofr1, ofr2))
181 return false;
182 return true;
183}
184
185/// Return a vector of OpFoldResults with the same size a staticValues, but all
186/// elements for which ShapedType::isDynamic is true, will be replaced by
187/// dynamicValues.
188SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
189 ValueRange dynamicValues,
190 MLIRContext *context) {
191 SmallVector<OpFoldResult> res;
192 res.reserve(N: staticValues.size());
193 unsigned numDynamic = 0;
194 unsigned count = static_cast<unsigned>(staticValues.size());
195 for (unsigned idx = 0; idx < count; ++idx) {
196 int64_t value = staticValues[idx];
197 res.push_back(ShapedType::isDynamic(value)
198 ? OpFoldResult{dynamicValues[numDynamic++]}
199 : OpFoldResult{IntegerAttr::get(
200 IntegerType::get(context, 64), staticValues[idx])});
201 }
202 return res;
203}
204SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
205 ValueRange dynamicValues, Builder &b) {
206 return getMixedValues(staticValues, dynamicValues, context: b.getContext());
207}
208
209/// Decompose a vector of mixed static or dynamic values into the corresponding
210/// pair of arrays. This is the inverse function of `getMixedValues`.
211std::pair<SmallVector<int64_t>, SmallVector<Value>>
212decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
213 SmallVector<int64_t> staticValues;
214 SmallVector<Value> dynamicValues;
215 for (const auto &it : mixedValues) {
216 if (auto attr = dyn_cast<Attribute>(Val: it)) {
217 staticValues.push_back(Elt: cast<IntegerAttr>(attr).getInt());
218 } else {
219 staticValues.push_back(ShapedType::kDynamic);
220 dynamicValues.push_back(Elt: cast<Value>(Val: it));
221 }
222 }
223 return {staticValues, dynamicValues};
224}
225
226/// Helper to sort `values` according to matching `keys`.
227template <typename K, typename V>
228static SmallVector<V>
229getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
230 llvm::function_ref<bool(K, K)> compare) {
231 if (keys.empty())
232 return SmallVector<V>{values};
233 assert(keys.size() == values.size() && "unexpected mismatching sizes");
234 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
235 llvm::sort(indices,
236 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
237 SmallVector<V> res;
238 res.reserve(values.size());
239 for (int64_t i = 0, e = indices.size(); i < e; ++i)
240 res.push_back(values[indices[i]]);
241 return res;
242}
243
244SmallVector<Value>
245getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
246 llvm::function_ref<bool(Attribute, Attribute)> compare) {
247 return getValuesSortedByKeyImpl(keys, values, compare);
248}
249
250SmallVector<OpFoldResult>
251getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
252 llvm::function_ref<bool(Attribute, Attribute)> compare) {
253 return getValuesSortedByKeyImpl(keys, values, compare);
254}
255
256SmallVector<int64_t>
257getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
258 llvm::function_ref<bool(Attribute, Attribute)> compare) {
259 return getValuesSortedByKeyImpl(keys, values, compare);
260}
261
262/// Return the number of iterations for a loop with a lower bound `lb`, upper
263/// bound `ub` and step `step`.
264std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
265 OpFoldResult step) {
266 if (lb == ub)
267 return 0;
268
269 std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb);
270 if (!lbConstant)
271 return std::nullopt;
272 std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub);
273 if (!ubConstant)
274 return std::nullopt;
275 std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step);
276 if (!stepConstant)
277 return std::nullopt;
278
279 return llvm::divideCeilSigned(Numerator: *ubConstant - *lbConstant, Denominator: *stepConstant);
280}
281
282bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
283 return llvm::none_of(Range&: sizesOrOffsets, P: [](int64_t value) {
284 return !ShapedType::isDynamic(value) && value < 0;
285 });
286}
287
288bool hasValidStrides(SmallVector<int64_t> strides) {
289 return llvm::none_of(Range&: strides, P: [](int64_t value) {
290 return !ShapedType::isDynamic(value) && value == 0;
291 });
292}
293
294LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
295 bool onlyNonNegative, bool onlyNonZero) {
296 bool valuesChanged = false;
297 for (OpFoldResult &ofr : ofrs) {
298 if (isa<Attribute>(Val: ofr))
299 continue;
300 Attribute attr;
301 if (matchPattern(value: cast<Value>(Val&: ofr), pattern: m_Constant(bind_value: &attr))) {
302 // Note: All ofrs have index type.
303 if (onlyNonNegative && *getConstantIntValue(ofr: attr) < 0)
304 continue;
305 if (onlyNonZero && *getConstantIntValue(ofr: attr) == 0)
306 continue;
307 ofr = attr;
308 valuesChanged = true;
309 }
310 }
311 return success(IsSuccess: valuesChanged);
312}
313
314LogicalResult
315foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
316 return foldDynamicIndexList(ofrs&: offsetsOrSizes, /*onlyNonNegative=*/true,
317 /*onlyNonZero=*/false);
318}
319
320LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
321 return foldDynamicIndexList(ofrs&: strides, /*onlyNonNegative=*/false,
322 /*onlyNonZero=*/true);
323}
324
325} // namespace mlir
326

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Utils/StaticValueUtils.cpp