//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for dealing with static values, e.g.,
// converting back and forth between Value and OpFoldResult. Such functionality
// is used in multiple dialects.
//
//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
16#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
17
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

#include <limits>
24
25namespace mlir {
26
27/// Return true if `v` is an IntegerAttr with value `0`.
28bool isZeroInteger(OpFoldResult v);
29
30/// Return true if `v` is an IntegerAttr with value `1`.
31bool isOneInteger(OpFoldResult v);
32
33/// Represents a range (offset, size, and stride) where each element of the
34/// triple may be dynamic or static.
35struct Range {
36 OpFoldResult offset;
37 OpFoldResult size;
38 OpFoldResult stride;
39};
40
41/// Given an array of Range values, return a tuple of (offset vector, sizes
42/// vector, and strides vector) formed by separating out the individual
43/// elements of each range.
44std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
45 SmallVector<OpFoldResult>>
46getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
47
48/// Helper function to dispatch an OpFoldResult into `staticVec` if:
49/// a) it is an IntegerAttr
50/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
51/// In such dynamic cases, ShapedType::kDynamic is also pushed to
52/// `staticVec`. This is useful to extract mixed static and dynamic entries
53/// that come from an AttrSizedOperandSegments trait.
54void dispatchIndexOpFoldResult(OpFoldResult ofr,
55 SmallVectorImpl<Value> &dynamicVec,
56 SmallVectorImpl<int64_t> &staticVec);
57
58/// Helper function to dispatch multiple OpFoldResults according to the
59/// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single
60/// OpFoldResult.
61void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
62 SmallVectorImpl<Value> &dynamicVec,
63 SmallVectorImpl<int64_t> &staticVec);
64
65/// Given OpFoldResult representing dim size value (*), generates a pair of
66/// sizes:
67/// * 1st result, static value, contains an int64_t dim size that can be used
68/// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
69/// * 2nd result, dynamic value, contains OpFoldResult encapsulating the
70/// actual dim size (either original or updated input value).
71/// For input sizes for which it is possible to extract a constant Attribute,
72/// replaces the original size value with an integer attribute (unless it's
73/// already a constant Attribute). The 1st return value also becomes the actual
74/// integer size (as opposed ShapedType::kDynamic).
75///
76/// (*) This hook is usually used when, given input sizes as OpFoldResult,
77/// it's required to generate two vectors:
78/// * sizes as int64_t to generate a shape,
79/// * sizes as OpFoldResult for sizes-like attribute.
80/// Please update this comment if you identify other use cases.
81std::pair<int64_t, OpFoldResult>
82getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);
83
84/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
85template <typename IntTy>
86SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
87 return llvm::to_vector(
88 llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
89 return cast<IntegerAttr>(a).getInt();
90 }));
91}
92
93/// Given a value, try to extract a constant Attribute. If this fails, return
94/// the original value.
95OpFoldResult getAsOpFoldResult(Value val);
96/// Given an array of values, try to extract a constant Attribute from each
97/// value. If this fails, return the original value.
98SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
99/// Convert `arrayAttr` to a vector of OpFoldResult.
100SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
101
102/// Convert int64_t to integer attributes of index type and return them as
103/// OpFoldResult.
104OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
105SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
106 ArrayRef<int64_t> values);
107
108/// If ofr is a constant integer or an IntegerAttr, return the integer.
109std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
110/// If all ofrs are constant integers or IntegerAttrs, return the integers.
111std::optional<SmallVector<int64_t>>
112getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
113
114/// Return true if `ofr` is constant integer equal to `value`.
115bool isConstantIntValue(OpFoldResult ofr, int64_t value);
116/// Return true if all of `ofrs` are constant integers equal to `value`.
117bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
118/// Return true if all of `ofrs` are constant integers equal to the
119/// corresponding value in `values`.
120bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
121 ArrayRef<int64_t> values);
122
123/// Return true if ofr1 and ofr2 are the same integer constant attribute
124/// values or the same SSA value. Ignore integer bitwitdh and type mismatch
125/// that come from the fact there is no IndexAttr and that IndexType have no
126/// bitwidth.
127bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
128bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
129 ArrayRef<OpFoldResult> ofrs2);
130
// To convert an OpFoldResult to a Value of index type, see:
// mlir/include/mlir/Dialect/Arith/Utils/Utils.h
// TODO: find a better common landing place.
//
// Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
//                                       OpFoldResult ofr);
137
// To convert an array of OpFoldResults to a vector of Values of index type,
// see: mlir/include/mlir/Dialect/Arith/Utils/Utils.h
// TODO: find a better common landing place.
//
// SmallVector<Value>
// getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
//                                 ArrayRef<OpFoldResult> valueOrAttrVec);
145
146/// Return a vector of OpFoldResults with the same size a staticValues, but
147/// all elements for which ShapedType::isDynamic is true, will be replaced by
148/// dynamicValues.
149SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
150 ValueRange dynamicValues,
151 MLIRContext *context);
152SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
153 ValueRange dynamicValues, Builder &b);
154
155/// Decompose a vector of mixed static or dynamic values into the
156/// corresponding pair of arrays. This is the inverse function of
157/// `getMixedValues`.
158std::pair<SmallVector<int64_t>, SmallVector<Value>>
159decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
160
161/// Helper to sort `values` according to matching `keys`.
162SmallVector<Value>
163getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
164 llvm::function_ref<bool(Attribute, Attribute)> compare);
165SmallVector<OpFoldResult>
166getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
167 llvm::function_ref<bool(Attribute, Attribute)> compare);
168SmallVector<int64_t>
169getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
170 llvm::function_ref<bool(Attribute, Attribute)> compare);
171
172/// Helper function to check whether the passed in `sizes` or `offsets` are
173/// valid. This can be used to re-check whether dimensions are still valid
174/// after constant folding the dynamic dimensions.
175bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
176
177/// Helper function to check whether the passed in `strides` are valid. This
178/// can be used to re-check whether dimensions are still valid after constant
179/// folding the dynamic dimensions.
180bool hasValidStrides(SmallVector<int64_t> strides);
181
182/// Returns "success" when any of the elements in `ofrs` is a constant value. In
183/// that case the value is replaced by an attribute. Returns "failure" when no
184/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only
185/// non-negative and non-zero constant values are folded respectively.
186LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
187 bool onlyNonNegative = false,
188 bool onlyNonZero = false);
189
190/// Returns "success" when any of the elements in `offsetsOrSizes` is a
191/// constant value. In that case the value is replaced by an attribute. Returns
192/// "failure" when no folding happened. Invalid values are not folded to avoid
193/// canonicalization crashes.
194LogicalResult
195foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
196
197/// Returns "success" when any of the elements in `strides` is a constant
198/// value. In that case the value is replaced by an attribute. Returns
199/// "failure" when no folding happened. Invalid values are not folded to avoid
200/// canonicalization crashes.
201LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
202
203/// Return the number of iterations for a loop with a lower bound `lb`, upper
204/// bound `ub` and step `step`.
205std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
206 OpFoldResult step);
207
208/// Idiomatic saturated operations on values like offsets, sizes, and strides.
209struct SaturatedInteger {
210 static SaturatedInteger wrap(int64_t v) {
211 return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
212 : SaturatedInteger{false, v};
213 }
214 int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
215 FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
216 if (saturated && !other.saturated)
217 return other;
218 if (!saturated && !other.saturated && v != other.v)
219 return failure();
220 return *this;
221 }
222 bool operator==(SaturatedInteger other) {
223 return (saturated && other.saturated) ||
224 (!saturated && !other.saturated && v == other.v);
225 }
226 bool operator!=(SaturatedInteger other) { return !(*this == other); }
227 SaturatedInteger operator+(SaturatedInteger other) {
228 if (saturated || other.saturated)
229 return SaturatedInteger{.saturated: true, .v: 0};
230 return SaturatedInteger{.saturated: false, .v: other.v + v};
231 }
232 SaturatedInteger operator*(SaturatedInteger other) {
233 // Multiplication with 0 is always 0.
234 if (!other.saturated && other.v == 0)
235 return SaturatedInteger{.saturated: false, .v: 0};
236 if (!saturated && v == 0)
237 return SaturatedInteger{.saturated: false, .v: 0};
238 // Otherwise, if this or the other integer is dynamic, so is the result.
239 if (saturated || other.saturated)
240 return SaturatedInteger{.saturated: true, .v: 0};
241 return SaturatedInteger{.saturated: false, .v: other.v * v};
242 }
243 bool saturated = true;
244 int64_t v = 0;
245};
246
247} // namespace mlir
248
249#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
250

// Source: mlir/include/mlir/Dialect/Utils/StaticValueUtils.h