1 | //===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file defines utilities for dealing with static values, e.g., |
10 | // converting back and forth between Value and OpFoldResult. Such functionality |
11 | // is used in multiple dialects. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H |
16 | #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H |
17 | |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/OpDefinition.h" |
21 | #include "mlir/Support/LLVM.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/ADT/SmallVectorExtras.h" |
24 | |
25 | namespace mlir { |
26 | |
27 | /// Return true if `v` is an IntegerAttr with value `0`. |
28 | bool isZeroInteger(OpFoldResult v); |
29 | |
30 | /// Return true if `v` is an IntegerAttr with value `1`. |
31 | bool isOneInteger(OpFoldResult v); |
32 | |
33 | /// Represents a range (offset, size, and stride) where each element of the |
34 | /// triple may be dynamic or static. |
35 | struct Range { |
36 | OpFoldResult offset; |
37 | OpFoldResult size; |
38 | OpFoldResult stride; |
39 | }; |
40 | |
41 | /// Given an array of Range values, return a tuple of (offset vector, sizes |
42 | /// vector, and strides vector) formed by separating out the individual |
43 | /// elements of each range. |
44 | std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
45 | SmallVector<OpFoldResult>> |
46 | getOffsetsSizesAndStrides(ArrayRef<Range> ranges); |
47 | |
48 | /// Helper function to dispatch an OpFoldResult into `staticVec` if: |
49 | /// a) it is an IntegerAttr |
50 | /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. |
51 | /// In such dynamic cases, ShapedType::kDynamic is also pushed to |
52 | /// `staticVec`. This is useful to extract mixed static and dynamic entries |
53 | /// that come from an AttrSizedOperandSegments trait. |
54 | void dispatchIndexOpFoldResult(OpFoldResult ofr, |
55 | SmallVectorImpl<Value> &dynamicVec, |
56 | SmallVectorImpl<int64_t> &staticVec); |
57 | |
58 | /// Helper function to dispatch multiple OpFoldResults according to the |
59 | /// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single |
60 | /// OpFoldResult. |
61 | void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, |
62 | SmallVectorImpl<Value> &dynamicVec, |
63 | SmallVectorImpl<int64_t> &staticVec); |
64 | |
65 | /// Given OpFoldResult representing dim size value (*), generates a pair of |
66 | /// sizes: |
67 | /// * 1st result, static value, contains an int64_t dim size that can be used |
68 | /// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims), |
69 | /// * 2nd result, dynamic value, contains OpFoldResult encapsulating the |
70 | /// actual dim size (either original or updated input value). |
71 | /// For input sizes for which it is possible to extract a constant Attribute, |
72 | /// replaces the original size value with an integer attribute (unless it's |
73 | /// already a constant Attribute). The 1st return value also becomes the actual |
74 | /// integer size (as opposed ShapedType::kDynamic). |
75 | /// |
76 | /// (*) This hook is usually used when, given input sizes as OpFoldResult, |
77 | /// it's required to generate two vectors: |
78 | /// * sizes as int64_t to generate a shape, |
79 | /// * sizes as OpFoldResult for sizes-like attribute. |
80 | /// Please update this comment if you identify other use cases. |
81 | std::pair<int64_t, OpFoldResult> |
82 | getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b); |
83 | |
84 | /// Extract integer values from the assumed ArrayAttr of IntegerAttr. |
85 | template <typename IntTy> |
86 | SmallVector<IntTy> (Attribute attr) { |
87 | return llvm::to_vector( |
88 | llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy { |
89 | return cast<IntegerAttr>(a).getInt(); |
90 | })); |
91 | } |
92 | |
93 | /// Given a value, try to extract a constant Attribute. If this fails, return |
94 | /// the original value. |
95 | OpFoldResult getAsOpFoldResult(Value val); |
96 | /// Given an array of values, try to extract a constant Attribute from each |
97 | /// value. If this fails, return the original value. |
98 | SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values); |
99 | /// Convert `arrayAttr` to a vector of OpFoldResult. |
100 | SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr); |
101 | |
102 | /// Convert int64_t to integer attributes of index type and return them as |
103 | /// OpFoldResult. |
104 | OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val); |
105 | SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx, |
106 | ArrayRef<int64_t> values); |
107 | |
108 | /// If ofr is a constant integer or an IntegerAttr, return the integer. |
109 | std::optional<int64_t> getConstantIntValue(OpFoldResult ofr); |
110 | /// If all ofrs are constant integers or IntegerAttrs, return the integers. |
111 | std::optional<SmallVector<int64_t>> |
112 | getConstantIntValues(ArrayRef<OpFoldResult> ofrs); |
113 | |
114 | /// Return true if `ofr` is constant integer equal to `value`. |
115 | bool isConstantIntValue(OpFoldResult ofr, int64_t value); |
116 | /// Return true if all of `ofrs` are constant integers equal to `value`. |
117 | bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value); |
118 | /// Return true if all of `ofrs` are constant integers equal to the |
119 | /// corresponding value in `values`. |
120 | bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs, |
121 | ArrayRef<int64_t> values); |
122 | |
123 | /// Return true if ofr1 and ofr2 are the same integer constant attribute |
124 | /// values or the same SSA value. Ignore integer bitwitdh and type mismatch |
125 | /// that come from the fact there is no IndexAttr and that IndexType have no |
126 | /// bitwidth. |
127 | bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); |
128 | bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1, |
129 | ArrayRef<OpFoldResult> ofrs2); |
130 | |
131 | // To convert an OpFoldResult to a Value of index type, see: |
132 | // mlir/include/mlir/Dialect/Arith/Utils/Utils.h |
133 | // TODO: find a better common landing place. |
134 | // |
135 | // Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
136 | // OpFoldResult ofr); |
137 | |
138 | // To convert an OpFoldResult to a Value of index type, see: |
139 | // mlir/include/mlir/Dialect/Arith/Utils/Utils.h |
140 | // TODO: find a better common landing place. |
141 | // |
142 | // SmallVector<Value> |
143 | // getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
144 | // ArrayRef<OpFoldResult> valueOrAttrVec); |
145 | |
146 | /// Return a vector of OpFoldResults with the same size a staticValues, but |
147 | /// all elements for which ShapedType::isDynamic is true, will be replaced by |
148 | /// dynamicValues. |
149 | SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, |
150 | ValueRange dynamicValues, |
151 | MLIRContext *context); |
152 | SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, |
153 | ValueRange dynamicValues, Builder &b); |
154 | |
155 | /// Decompose a vector of mixed static or dynamic values into the |
156 | /// corresponding pair of arrays. This is the inverse function of |
157 | /// `getMixedValues`. |
158 | std::pair<SmallVector<int64_t>, SmallVector<Value>> |
159 | decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues); |
160 | |
161 | /// Helper to sort `values` according to matching `keys`. |
162 | SmallVector<Value> |
163 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, |
164 | llvm::function_ref<bool(Attribute, Attribute)> compare); |
165 | SmallVector<OpFoldResult> |
166 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values, |
167 | llvm::function_ref<bool(Attribute, Attribute)> compare); |
168 | SmallVector<int64_t> |
169 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, |
170 | llvm::function_ref<bool(Attribute, Attribute)> compare); |
171 | |
172 | /// Helper function to check whether the passed in `sizes` or `offsets` are |
173 | /// valid. This can be used to re-check whether dimensions are still valid |
174 | /// after constant folding the dynamic dimensions. |
175 | bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets); |
176 | |
177 | /// Helper function to check whether the passed in `strides` are valid. This |
178 | /// can be used to re-check whether dimensions are still valid after constant |
179 | /// folding the dynamic dimensions. |
180 | bool hasValidStrides(SmallVector<int64_t> strides); |
181 | |
182 | /// Returns "success" when any of the elements in `ofrs` is a constant value. In |
183 | /// that case the value is replaced by an attribute. Returns "failure" when no |
184 | /// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only |
185 | /// non-negative and non-zero constant values are folded respectively. |
186 | LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, |
187 | bool onlyNonNegative = false, |
188 | bool onlyNonZero = false); |
189 | |
190 | /// Returns "success" when any of the elements in `offsetsOrSizes` is a |
191 | /// constant value. In that case the value is replaced by an attribute. Returns |
192 | /// "failure" when no folding happened. Invalid values are not folded to avoid |
193 | /// canonicalization crashes. |
194 | LogicalResult |
195 | foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes); |
196 | |
197 | /// Returns "success" when any of the elements in `strides` is a constant |
198 | /// value. In that case the value is replaced by an attribute. Returns |
199 | /// "failure" when no folding happened. Invalid values are not folded to avoid |
200 | /// canonicalization crashes. |
201 | LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides); |
202 | |
203 | /// Return the number of iterations for a loop with a lower bound `lb`, upper |
204 | /// bound `ub` and step `step`. |
205 | std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, |
206 | OpFoldResult step); |
207 | |
208 | /// Idiomatic saturated operations on values like offsets, sizes, and strides. |
209 | struct SaturatedInteger { |
210 | static SaturatedInteger wrap(int64_t v) { |
211 | return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0} |
212 | : SaturatedInteger{false, v}; |
213 | } |
214 | int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; } |
215 | FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) { |
216 | if (saturated && !other.saturated) |
217 | return other; |
218 | if (!saturated && !other.saturated && v != other.v) |
219 | return failure(); |
220 | return *this; |
221 | } |
222 | bool operator==(SaturatedInteger other) { |
223 | return (saturated && other.saturated) || |
224 | (!saturated && !other.saturated && v == other.v); |
225 | } |
226 | bool operator!=(SaturatedInteger other) { return !(*this == other); } |
227 | SaturatedInteger operator+(SaturatedInteger other) { |
228 | if (saturated || other.saturated) |
229 | return SaturatedInteger{.saturated: true, .v: 0}; |
230 | return SaturatedInteger{.saturated: false, .v: other.v + v}; |
231 | } |
232 | SaturatedInteger operator*(SaturatedInteger other) { |
233 | // Multiplication with 0 is always 0. |
234 | if (!other.saturated && other.v == 0) |
235 | return SaturatedInteger{.saturated: false, .v: 0}; |
236 | if (!saturated && v == 0) |
237 | return SaturatedInteger{.saturated: false, .v: 0}; |
238 | // Otherwise, if this or the other integer is dynamic, so is the result. |
239 | if (saturated || other.saturated) |
240 | return SaturatedInteger{.saturated: true, .v: 0}; |
241 | return SaturatedInteger{.saturated: false, .v: other.v * v}; |
242 | } |
243 | bool saturated = true; |
244 | int64_t v = 0; |
245 | }; |
246 | |
247 | } // namespace mlir |
248 | |
249 | #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H |
250 | |