1 | //===- ViewLikeInterface.h - View-like operations interface ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the operation interface for view-like operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ |
14 | #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ |
15 | |
16 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/BuiltinAttributes.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "mlir/IR/OpImplementation.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | |
23 | namespace mlir { |
24 | |
25 | class OffsetSizeAndStrideOpInterface; |
26 | |
27 | namespace detail { |
28 | |
29 | LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); |
30 | |
31 | bool sameOffsetsSizesAndStrides( |
32 | OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, |
33 | llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp); |
34 | |
35 | /// Helper method to compute the number of dynamic entries of `staticVals`, |
36 | /// up to `idx`. |
37 | unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, |
38 | unsigned idx); |
39 | |
40 | } // namespace detail |
41 | } // namespace mlir |
42 | |
43 | /// Include the generated interface declarations. |
44 | #include "mlir/Interfaces/ViewLikeInterface.h.inc" |
45 | |
46 | namespace mlir { |
47 | |
/// Result of slice bounds verification.
struct SliceBoundsVerificationResult {
  /// If set to "true", the slice bounds verification was successful.
  /// Defaults to false so that a default-constructed result never carries an
  /// indeterminate value.
  bool isValid = false;
  /// An error message that can be printed during op verification. Only
  /// meaningful when the verifier was asked to generate one.
  std::string errorMessage;
};
55 | |
56 | /// Verify that the offsets/sizes/strides-style access into the given shape |
57 | /// is in-bounds. Only static values are verified. If `generateErrorMessage` |
58 | /// is set to "true", an error message is produced that can be printed by the |
59 | /// op verifier. |
60 | SliceBoundsVerificationResult |
61 | verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets, |
62 | ArrayRef<int64_t> staticSizes, |
63 | ArrayRef<int64_t> staticStrides, |
64 | bool generateErrorMessage = false); |
65 | SliceBoundsVerificationResult verifyInBoundsSlice( |
66 | ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets, |
67 | ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides, |
68 | bool generateErrorMessage = false); |
69 | |
70 | /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as |
71 | /// constant arguments. This pattern assumes that the op has a suitable builder |
72 | /// that takes a result type, a "source" operand and mixed offsets, sizes and |
73 | /// strides. |
74 | /// |
75 | /// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn` |
76 | /// returns the new result type of the op, based on the new offsets, sizes and |
77 | /// strides. `CastOpFunc` is used to generate a cast op if the result type of |
78 | /// the op has changed. |
79 | template <typename OpType, typename ResultTypeFn, typename CastOpFunc> |
80 | class OpWithOffsetSizesAndStridesConstantArgumentFolder final |
81 | : public OpRewritePattern<OpType> { |
82 | public: |
83 | using OpRewritePattern<OpType>::OpRewritePattern; |
84 | |
85 | LogicalResult matchAndRewrite(OpType op, |
86 | PatternRewriter &rewriter) const override { |
87 | SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets()); |
88 | SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes()); |
89 | SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides()); |
90 | |
91 | // No constant operands were folded, just return; |
92 | if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) && |
93 | failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) && |
94 | failed(foldDynamicIndexList(mixedStrides))) |
95 | return failure(); |
96 | |
97 | // Pattern does not apply if the produced op would not verify. |
98 | SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice( |
99 | cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets, |
100 | mixedSizes, mixedStrides); |
101 | if (!sliceResult.isValid) |
102 | return failure(); |
103 | |
104 | // Compute the new result type. |
105 | auto resultType = |
106 | ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides); |
107 | if (!resultType) |
108 | return failure(); |
109 | |
110 | // Create the new op in canonical form. |
111 | auto newOp = |
112 | rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(), |
113 | mixedOffsets, mixedSizes, mixedStrides); |
114 | CastOpFunc()(rewriter, op, newOp); |
115 | |
116 | return success(); |
117 | } |
118 | }; |
119 | |
120 | /// Printer hooks for custom directive in assemblyFormat. |
121 | /// |
122 | /// custom<DynamicIndexList>($values, $integers) |
123 | /// custom<DynamicIndexList>($values, $integers, type($values)) |
124 | /// |
125 | /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS type |
126 | /// `I64ArrayAttr`. Print a list where each element is either: |
127 | /// 1. the static integer value in `integers`, if it's not `kDynamic` or, |
128 | /// 2. the next value in `values`, otherwise. |
129 | /// |
130 | /// If `valueTypes` is provided, the corresponding type of each dynamic value is |
131 | /// printed. Otherwise, the type is not printed. Each type must match the type |
132 | /// of the corresponding value in `values`. `valueTypes` is redundant for |
133 | /// printing as we can retrieve the types from the actual `values`. However, |
134 | /// `valueTypes` is needed for parsing and we must keep the API symmetric for |
135 | /// parsing and printing. The type for integer elements is `i64` by default and |
136 | /// never printed. |
137 | /// |
138 | /// Integer indices can also be scalable in the context of scalable vectors, |
139 | /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in |
140 | /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's |
141 | /// a scalable index. If `scalableFlags` is empty then assume that all indices |
142 | /// are non-scalable. |
143 | /// |
144 | /// Examples: |
145 | /// |
146 | /// * Input: `integers = [kDynamic, 7, 42, kDynamic]`, |
147 | /// `values = [%arg0, %arg42]` and |
148 | /// `valueTypes = [index, index]` |
149 | /// prints: |
150 | /// `[%arg0 : index, 7, 42, %arg42 : i32]` |
151 | /// |
152 | /// * Input: `integers = [kDynamic, 7, 42, kDynamic]`, |
153 | /// `values = [%arg0, %arg42]` and |
154 | /// `valueTypes = []` |
155 | /// prints: |
156 | /// `[%arg0, 7, 42, %arg42]` |
157 | /// |
158 | /// * Input: `integers = [2, 4, 8]`, |
159 | /// `values = []` and |
160 | /// `scalableFlags = [false, true, false]` |
161 | /// prints: |
162 | /// `[2, [4], 8]` |
163 | /// |
164 | void printDynamicIndexList( |
165 | OpAsmPrinter &printer, Operation *op, OperandRange values, |
166 | ArrayRef<int64_t> integers, ArrayRef<bool> scalableFlags, |
167 | TypeRange valueTypes = TypeRange(), |
168 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); |
169 | inline void printDynamicIndexList( |
170 | OpAsmPrinter &printer, Operation *op, OperandRange values, |
171 | ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(), |
172 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { |
173 | return printDynamicIndexList(printer, op, values, integers, |
174 | /*scalableFlags=*/{}, valueTypes, delimiter); |
175 | } |
176 | |
177 | /// Parser hooks for custom directive in assemblyFormat. |
178 | /// |
179 | /// custom<DynamicIndexList>($values, $integers) |
180 | /// custom<DynamicIndexList>($values, $integers, type($values)) |
181 | /// |
182 | /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS |
183 | /// type `I64ArrayAttr`. Parse a mixed list where each element is either a |
184 | /// static integer or an SSA value. Fill `integers` with the integer ArrayAttr, |
185 | /// where `kDynamic` encodes the position of SSA values. Add the parsed SSA |
186 | /// values to `values` in-order. |
187 | /// |
188 | /// If `valueTypes` is provided, fill it with the types corresponding to each |
189 | /// value in `values`. Otherwise, the caller must handle the types and parsing |
190 | /// will fail if the type of the value is found (e.g., `[%arg0 : index, 3, %arg1 |
191 | /// : index]`). |
192 | /// |
193 | /// Integer indices can also be scalable in the context of scalable vectors, |
194 | /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in |
195 | /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's |
196 | /// a scalable index. |
197 | /// |
198 | /// Examples: |
199 | /// |
200 | /// * After parsing "[%arg0 : index, 7, 42, %arg42 : i32]": |
201 | /// 1. `result` is filled with `[kDynamic, 7, 42, kDynamic]` |
202 | /// 2. `values` is filled with "[%arg0, %arg1]". |
203 | /// 3. `scalableFlags` is filled with `[false, true, false]`. |
204 | /// |
205 | /// * After parsing `[2, [4], 8]`: |
206 | /// 1. `result` is filled with `[2, 4, 8]` |
207 | /// 2. `values` is empty. |
208 | /// 3. `scalableFlags` is filled with `[false, true, false]`. |
209 | /// |
210 | ParseResult parseDynamicIndexList( |
211 | OpAsmParser &parser, |
212 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
213 | DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, |
214 | SmallVectorImpl<Type> *valueTypes = nullptr, |
215 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); |
216 | inline ParseResult parseDynamicIndexList( |
217 | OpAsmParser &parser, |
218 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
219 | DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr, |
220 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { |
221 | DenseBoolArrayAttr scalableFlags; |
222 | return parseDynamicIndexList(parser, values, integers, scalableFlags, |
223 | valueTypes, delimiter); |
224 | } |
225 | |
226 | /// Verify that a the `values` has as many elements as the number of entries in |
227 | /// `attr` for which `isDynamic` evaluates to true. |
228 | LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, |
229 | unsigned expectedNumElements, |
230 | ArrayRef<int64_t> attr, |
231 | ValueRange values); |
232 | |
233 | } // namespace mlir |
234 | |
235 | #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ |
236 | |