1//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines utilities for dealing with static values, e.g.,
10// converting back and forth between Value and OpFoldResult. Such functionality
11// is used in multiple dialects.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
16#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
17
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/MathExtras.h"
24
25namespace mlir {
26
/// Return true if `v` is an IntegerAttr with value `0`, or a ConstantIndexOp
/// whose attribute has value `0`.
bool isZeroIndex(OpFoldResult v);
30
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic (an SSA Value) or static (an Attribute), as carried
/// by OpFoldResult.
struct Range {
  /// Offset component of the range.
  OpFoldResult offset;
  /// Size component of the range.
  OpFoldResult size;
  /// Stride component of the range.
  OpFoldResult stride;
};
38
/// Given an array of Range values, return a tuple of (offsets vector, sizes
/// vector, strides vector) formed by separating out the individual elements
/// of each range.
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
           SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges);
45
/// Helper function to dispatch an OpFoldResult into `staticVec` if it is an
/// IntegerAttr. In all other cases, the OpFoldResult is dispatched to
/// `dynamicVec`, and ShapedType::kDynamic is pushed to `staticVec` as a
/// placeholder for it. This is useful to extract mixed static and dynamic
/// entries that come from an AttrSizedOperandSegments trait.
void dispatchIndexOpFoldResult(OpFoldResult ofr,
                               SmallVectorImpl<Value> &dynamicVec,
                               SmallVectorImpl<int64_t> &staticVec);

/// Helper function to dispatch multiple OpFoldResults, applying the behavior
/// of `dispatchIndexOpFoldResult(OpFoldResult ofr, ...)` to each element of
/// `ofrs`.
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec);
62
63/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
64template <typename IntTy>
65SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
66 return llvm::to_vector(
67 llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
68 return cast<IntegerAttr>(a).getInt();
69 }));
70}
71
/// Given a value, try to extract a constant Attribute. If this fails, return
/// the original value.
OpFoldResult getAsOpFoldResult(Value val);
/// Given an array of values, try to extract a constant Attribute from each
/// value. For each value where this fails, the original value is returned in
/// its place.
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
/// Convert each element of `arrayAttr` to an OpFoldResult.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
80
/// Convert `val` to an integer attribute of index type (created in `ctx`) and
/// return it as an OpFoldResult.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
/// Convert each element of `values` to an integer attribute of index type
/// (created in `ctx`) and return them as OpFoldResults.
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
                                                 ArrayRef<int64_t> values);
86
/// If `ofr` is a constant integer or an IntegerAttr, return the integer;
/// otherwise return std::nullopt.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// If all `ofrs` are constant integers or IntegerAttrs, return the integers;
/// otherwise return std::nullopt.
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs);

/// Return true if `ofr` is a constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
95
/// Return true if `ofr1` and `ofr2` are the same integer constant attribute
/// values or the same SSA value. Integer bitwidth and type mismatches are
/// ignored, because there is no IndexAttr and IndexType has no bitwidth.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
/// Array variant of `isEqualConstantIntOrValue`, comparing `ofrs1` and
/// `ofrs2` element-wise. NOTE(review): presumably returns false when the
/// arrays differ in size — confirm against the implementation.
bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
                                    ArrayRef<OpFoldResult> ofrs2);
103
104// To convert an OpFoldResult to a Value of index type, see:
105// mlir/include/mlir/Dialect/Arith/Utils/Utils.h
106// TODO: find a better common landing place.
107//
108// Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
109// OpFoldResult ofr);
110
111// To convert an OpFoldResult to a Value of index type, see:
112// mlir/include/mlir/Dialect/Arith/Utils/Utils.h
113// TODO: find a better common landing place.
114//
115// SmallVector<Value>
116// getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
117// ArrayRef<OpFoldResult> valueOrAttrVec);
118
/// Return a vector of OpFoldResults with the same size as `staticValues`, in
/// which every element for which ShapedType::isDynamic is true is replaced by
/// an entry of `dynamicValues` (presumably consumed in order — confirm).
/// `b` is the Builder used by the implementation, presumably to create
/// attributes for the static entries.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                         ValueRange dynamicValues, Builder &b);
124
/// Decompose a vector of mixed static or dynamic values into the
/// corresponding pair of an ArrayAttr (static part) and a vector of Values
/// (dynamic part). This is the inverse function of `getMixedValues`.
std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
                     const SmallVectorImpl<OpFoldResult> &mixedValues);
131
/// Helper to sort `values` according to their matching `keys`, where keys
/// are ordered by `compare`. Overloads are provided for Values,
/// OpFoldResults, and int64_t. NOTE(review): `keys` and `values` are
/// presumably expected to have the same size — confirm in the implementation.
SmallVector<Value>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare);
SmallVector<OpFoldResult>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare);
SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare);
142
/// Helper function to check whether the passed in sizes or offsets
/// (`sizesOrOffsets`) are valid. This can be used to re-check whether
/// dimensions are still valid after constant folding the dynamic dimensions.
/// NOTE(review): the exact validity criterion (e.g. rejecting negative static
/// values) lives in the implementation — confirm there.
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);

/// Helper function to check whether the passed in `strides` are valid. This
/// can be used to re-check whether dimensions are still valid after constant
/// folding the dynamic dimensions.
/// NOTE(review): the exact validity criterion lives in the implementation —
/// confirm there.
bool hasValidStrides(SmallVector<int64_t> strides);
152
/// Returns "success" when any of the elements in `ofrs` is a constant value.
/// In that case the value is replaced by an attribute. Returns "failure" when
/// no folding happened. If `onlyNonNegative` is set, only non-negative
/// constant values are folded; if `onlyNonZero` is set, only non-zero
/// constant values are folded.
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
                                   bool onlyNonNegative = false,
                                   bool onlyNonZero = false);

/// Returns "success" when any of the elements in `offsetsOrSizes` is a
/// constant value. In that case the value is replaced by an attribute. Returns
/// "failure" when no folding happened. Invalid values are not folded, to avoid
/// canonicalization crashes.
LogicalResult
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);

/// Returns "success" when any of the elements in `strides` is a constant
/// value. In that case the value is replaced by an attribute. Returns
/// "failure" when no folding happened. Invalid values are not folded, to avoid
/// canonicalization crashes.
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
173
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`, or std::nullopt if it cannot be determined as
/// a constant.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                         OpFoldResult step);
178
179/// Idiomatic saturated operations on values like offsets, sizes, and strides.
180struct SaturatedInteger {
181 static SaturatedInteger wrap(int64_t v) {
182 return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
183 : SaturatedInteger{false, v};
184 }
185 int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
186 FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
187 if (saturated && !other.saturated)
188 return other;
189 if (!saturated && !other.saturated && v != other.v)
190 return failure();
191 return *this;
192 }
193 bool operator==(SaturatedInteger other) {
194 return (saturated && other.saturated) ||
195 (!saturated && !other.saturated && v == other.v);
196 }
197 bool operator!=(SaturatedInteger other) { return !(*this == other); }
198 SaturatedInteger operator+(SaturatedInteger other) {
199 if (saturated || other.saturated)
200 return SaturatedInteger{.saturated: true, .v: 0};
201 return SaturatedInteger{.saturated: false, .v: other.v + v};
202 }
203 SaturatedInteger operator*(SaturatedInteger other) {
204 if (saturated || other.saturated)
205 return SaturatedInteger{.saturated: true, .v: 0};
206 return SaturatedInteger{.saturated: false, .v: other.v * v};
207 }
208 bool saturated = true;
209 int64_t v = 0;
210};
211
212} // namespace mlir
213
214#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
215

source code of mlir/include/mlir/Dialect/Utils/StaticValueUtils.h