1 | //===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file defines utilities for dealing with static values, e.g., |
10 | // converting back and forth between Value and OpFoldResult. Such functionality |
11 | // is used in multiple dialects. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H |
16 | #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H |
17 | |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/OpDefinition.h" |
21 | #include "mlir/Support/LLVM.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/ADT/SmallVectorExtras.h" |
24 | |
25 | namespace mlir { |
26 | |
27 | /// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp |
28 | /// with attribute with value `0`. |
29 | bool isZeroIndex(OpFoldResult v); |
30 | |
31 | /// Represents a range (offset, size, and stride) where each element of the |
32 | /// triple may be dynamic or static. |
33 | struct Range { |
34 | OpFoldResult offset; |
35 | OpFoldResult size; |
36 | OpFoldResult stride; |
37 | }; |
38 | |
39 | /// Given an array of Range values, return a tuple of (offset vector, sizes |
40 | /// vector, and strides vector) formed by separating out the individual |
41 | /// elements of each range. |
42 | std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
43 | SmallVector<OpFoldResult>> |
44 | getOffsetsSizesAndStrides(ArrayRef<Range> ranges); |
45 | |
46 | /// Helper function to dispatch an OpFoldResult into `staticVec` if: |
47 | /// a) it is an IntegerAttr |
48 | /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. |
49 | /// In such dynamic cases, ShapedType::kDynamic is also pushed to |
50 | /// `staticVec`. This is useful to extract mixed static and dynamic entries |
51 | /// that come from an AttrSizedOperandSegments trait. |
52 | void dispatchIndexOpFoldResult(OpFoldResult ofr, |
53 | SmallVectorImpl<Value> &dynamicVec, |
54 | SmallVectorImpl<int64_t> &staticVec); |
55 | |
56 | /// Helper function to dispatch multiple OpFoldResults according to the |
57 | /// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single |
58 | /// OpFoldResult. |
59 | void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, |
60 | SmallVectorImpl<Value> &dynamicVec, |
61 | SmallVectorImpl<int64_t> &staticVec); |
62 | |
63 | /// Extract integer values from the assumed ArrayAttr of IntegerAttr. |
64 | template <typename IntTy> |
65 | SmallVector<IntTy> (Attribute attr) { |
66 | return llvm::to_vector( |
67 | llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy { |
68 | return cast<IntegerAttr>(a).getInt(); |
69 | })); |
70 | } |
71 | |
72 | /// Given a value, try to extract a constant Attribute. If this fails, return |
73 | /// the original value. |
74 | OpFoldResult getAsOpFoldResult(Value val); |
75 | /// Given an array of values, try to extract a constant Attribute from each |
76 | /// value. If this fails, return the original value. |
77 | SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values); |
78 | /// Convert `arrayAttr` to a vector of OpFoldResult. |
79 | SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr); |
80 | |
81 | /// Convert int64_t to integer attributes of index type and return them as |
82 | /// OpFoldResult. |
83 | OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val); |
84 | SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx, |
85 | ArrayRef<int64_t> values); |
86 | |
87 | /// If ofr is a constant integer or an IntegerAttr, return the integer. |
88 | std::optional<int64_t> getConstantIntValue(OpFoldResult ofr); |
89 | /// If all ofrs are constant integers or IntegerAttrs, return the integers. |
90 | std::optional<SmallVector<int64_t>> |
91 | getConstantIntValues(ArrayRef<OpFoldResult> ofrs); |
92 | |
93 | /// Return true if `ofr` is constant integer equal to `value`. |
94 | bool isConstantIntValue(OpFoldResult ofr, int64_t value); |
95 | |
96 | /// Return true if ofr1 and ofr2 are the same integer constant attribute |
97 | /// values or the same SSA value. Ignore integer bitwitdh and type mismatch |
98 | /// that come from the fact there is no IndexAttr and that IndexType have no |
99 | /// bitwidth. |
100 | bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); |
101 | bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1, |
102 | ArrayRef<OpFoldResult> ofrs2); |
103 | |
104 | // To convert an OpFoldResult to a Value of index type, see: |
105 | // mlir/include/mlir/Dialect/Arith/Utils/Utils.h |
106 | // TODO: find a better common landing place. |
107 | // |
108 | // Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
109 | // OpFoldResult ofr); |
110 | |
// To convert a list of OpFoldResults to Values of index type, see:
112 | // mlir/include/mlir/Dialect/Arith/Utils/Utils.h |
113 | // TODO: find a better common landing place. |
114 | // |
115 | // SmallVector<Value> |
116 | // getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, |
117 | // ArrayRef<OpFoldResult> valueOrAttrVec); |
118 | |
119 | /// Return a vector of OpFoldResults with the same size a staticValues, but |
120 | /// all elements for which ShapedType::isDynamic is true, will be replaced by |
121 | /// dynamicValues. |
122 | SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues, |
123 | ValueRange dynamicValues, Builder &b); |
124 | |
125 | /// Decompose a vector of mixed static or dynamic values into the |
126 | /// corresponding pair of arrays. This is the inverse function of |
127 | /// `getMixedValues`. |
128 | std::pair<ArrayAttr, SmallVector<Value>> |
129 | decomposeMixedValues(Builder &b, |
130 | const SmallVectorImpl<OpFoldResult> &mixedValues); |
131 | |
132 | /// Helper to sort `values` according to matching `keys`. |
133 | SmallVector<Value> |
134 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values, |
135 | llvm::function_ref<bool(Attribute, Attribute)> compare); |
136 | SmallVector<OpFoldResult> |
137 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values, |
138 | llvm::function_ref<bool(Attribute, Attribute)> compare); |
139 | SmallVector<int64_t> |
140 | getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values, |
141 | llvm::function_ref<bool(Attribute, Attribute)> compare); |
142 | |
143 | /// Helper function to check whether the passed in `sizes` or `offsets` are |
144 | /// valid. This can be used to re-check whether dimensions are still valid |
145 | /// after constant folding the dynamic dimensions. |
146 | bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets); |
147 | |
148 | /// Helper function to check whether the passed in `strides` are valid. This |
149 | /// can be used to re-check whether dimensions are still valid after constant |
150 | /// folding the dynamic dimensions. |
151 | bool hasValidStrides(SmallVector<int64_t> strides); |
152 | |
153 | /// Returns "success" when any of the elements in `ofrs` is a constant value. In |
154 | /// that case the value is replaced by an attribute. Returns "failure" when no |
155 | /// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only |
156 | /// non-negative and non-zero constant values are folded respectively. |
157 | LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs, |
158 | bool onlyNonNegative = false, |
159 | bool onlyNonZero = false); |
160 | |
161 | /// Returns "success" when any of the elements in `offsetsOrSizes` is a |
162 | /// constant value. In that case the value is replaced by an attribute. Returns |
163 | /// "failure" when no folding happened. Invalid values are not folded to avoid |
164 | /// canonicalization crashes. |
165 | LogicalResult |
166 | foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes); |
167 | |
168 | /// Returns "success" when any of the elements in `strides` is a constant |
169 | /// value. In that case the value is replaced by an attribute. Returns |
170 | /// "failure" when no folding happened. Invalid values are not folded to avoid |
171 | /// canonicalization crashes. |
172 | LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides); |
173 | |
174 | /// Return the number of iterations for a loop with a lower bound `lb`, upper |
175 | /// bound `ub` and step `step`. |
176 | std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, |
177 | OpFoldResult step); |
178 | |
179 | /// Idiomatic saturated operations on values like offsets, sizes, and strides. |
180 | struct SaturatedInteger { |
181 | static SaturatedInteger wrap(int64_t v) { |
182 | return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0} |
183 | : SaturatedInteger{false, v}; |
184 | } |
185 | int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; } |
186 | FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) { |
187 | if (saturated && !other.saturated) |
188 | return other; |
189 | if (!saturated && !other.saturated && v != other.v) |
190 | return failure(); |
191 | return *this; |
192 | } |
193 | bool operator==(SaturatedInteger other) { |
194 | return (saturated && other.saturated) || |
195 | (!saturated && !other.saturated && v == other.v); |
196 | } |
197 | bool operator!=(SaturatedInteger other) { return !(*this == other); } |
198 | SaturatedInteger operator+(SaturatedInteger other) { |
199 | if (saturated || other.saturated) |
200 | return SaturatedInteger{.saturated: true, .v: 0}; |
201 | return SaturatedInteger{.saturated: false, .v: other.v + v}; |
202 | } |
203 | SaturatedInteger operator*(SaturatedInteger other) { |
204 | if (saturated || other.saturated) |
205 | return SaturatedInteger{.saturated: true, .v: 0}; |
206 | return SaturatedInteger{.saturated: false, .v: other.v * v}; |
207 | } |
208 | bool saturated = true; |
209 | int64_t v = 0; |
210 | }; |
211 | |
212 | } // namespace mlir |
213 | |
214 | #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H |
215 | |