1 | //===- ViewLikeInterface.h - View-like operations interface ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the operation interface for view-like operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ |
14 | #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ |
15 | |
16 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/BuiltinAttributes.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "mlir/IR/OpImplementation.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | |
23 | namespace mlir { |
24 | |
25 | class OffsetSizeAndStrideOpInterface; |
26 | |
27 | namespace detail { |
28 | |
29 | LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); |
30 | |
31 | bool sameOffsetsSizesAndStrides( |
32 | OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, |
33 | llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp); |
34 | |
35 | /// Helper method to compute the number of dynamic entries of `staticVals`, |
36 | /// up to `idx`. |
37 | unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, |
38 | unsigned idx); |
39 | |
40 | } // namespace detail |
41 | } // namespace mlir |
42 | |
43 | /// Include the generated interface declarations. |
44 | #include "mlir/Interfaces/ViewLikeInterface.h.inc" |
45 | |
46 | namespace mlir { |
47 | |
48 | /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as |
49 | /// constant arguments. This pattern assumes that the op has a suitable builder |
50 | /// that takes a result type, a "source" operand and mixed offsets, sizes and |
51 | /// strides. |
52 | /// |
53 | /// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn` |
54 | /// returns the new result type of the op, based on the new offsets, sizes and |
55 | /// strides. `CastOpFunc` is used to generate a cast op if the result type of |
56 | /// the op has changed. |
57 | template <typename OpType, typename ResultTypeFn, typename CastOpFunc> |
58 | class OpWithOffsetSizesAndStridesConstantArgumentFolder final |
59 | : public OpRewritePattern<OpType> { |
60 | public: |
61 | using OpRewritePattern<OpType>::OpRewritePattern; |
62 | |
63 | LogicalResult matchAndRewrite(OpType op, |
64 | PatternRewriter &rewriter) const override { |
65 | SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets()); |
66 | SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes()); |
67 | SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides()); |
68 | |
69 | // No constant operands were folded, just return; |
70 | if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) && |
71 | failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) && |
72 | failed(foldDynamicIndexList(mixedStrides))) |
73 | return failure(); |
74 | |
75 | // Create the new op in canonical form. |
76 | auto resultType = |
77 | ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides); |
78 | if (!resultType) |
79 | return failure(); |
80 | auto newOp = |
81 | rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(), |
82 | mixedOffsets, mixedSizes, mixedStrides); |
83 | CastOpFunc()(rewriter, op, newOp); |
84 | |
85 | return success(); |
86 | } |
87 | }; |
88 | |
89 | /// Printer hook for custom directive in assemblyFormat. |
90 | /// |
91 | /// custom<DynamicIndexList>($values, $integers) |
92 | /// custom<DynamicIndexList>($values, $integers, type($values)) |
93 | /// |
94 | /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS |
95 | /// type `I64ArrayAttr`. Prints a list with either (1) the static integer value |
96 | /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` |
97 | /// is non-empty, it is expected to contain as many elements as `values` |
98 | /// indicating their types. This allows idiomatic printing of mixed value and |
99 | /// integer attributes in a list. E.g. |
100 | /// `[%arg0 : index, 7, 42, %arg42 : i32]`. |
101 | /// |
102 | /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. |
103 | /// This notation is similar to how scalable dims are marked when defining |
104 | /// Vectors. For each value in `integers`, the corresponding `bool` in |
105 | /// `scalables` encodes whether it's a scalable index. If `scalableVals` is |
106 | /// empty then assume that all indices are non-scalable. |
107 | void printDynamicIndexList( |
108 | OpAsmPrinter &printer, Operation *op, OperandRange values, |
109 | ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(), |
110 | ArrayRef<bool> scalables = {}, |
111 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); |
112 | |
113 | /// Parser hook for custom directive in assemblyFormat. |
114 | /// |
115 | /// custom<DynamicIndexList>($values, $integers) |
116 | /// custom<DynamicIndexList>($values, $integers, type($values)) |
117 | /// |
118 | /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS |
119 | /// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer |
120 | /// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where |
121 | /// `kDynamic` encodes the position of SSA values. Add the parsed SSA values |
122 | /// to `values` in-order. If `valueTypes` is non-null, fill it with types |
123 | /// corresponding to values; otherwise the caller must handle the types. |
124 | /// |
125 | /// E.g. after parsing "[%arg0 : index, 7, 42, %arg42 : i32]": |
126 | /// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42, |
127 | /// `kDynamic`]" |
128 | /// 2. `ssa` is filled with "[%arg0, %arg1]". |
129 | /// |
130 | /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. |
131 | /// This notation is similar to how scalable dims are marked when defining |
132 | /// Vectors. For each value in `integers`, the corresponding `bool` in |
133 | /// `scalableVals` encodes whether it's a scalable index. |
134 | ParseResult parseDynamicIndexList( |
135 | OpAsmParser &parser, |
136 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
137 | DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, |
138 | SmallVectorImpl<Type> *valueTypes = nullptr, |
139 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); |
140 | inline ParseResult parseDynamicIndexList( |
141 | OpAsmParser &parser, |
142 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
143 | DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr, |
144 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { |
145 | DenseBoolArrayAttr scalableVals = {}; |
146 | return parseDynamicIndexList(parser, values, integers, scalableVals, |
147 | valueTypes, delimiter); |
148 | } |
149 | inline ParseResult parseDynamicIndexList( |
150 | OpAsmParser &parser, |
151 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
152 | DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes, |
153 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { |
154 | DenseBoolArrayAttr scalableVals = {}; |
155 | return parseDynamicIndexList(parser, values, integers, scalableVals, |
156 | &valueTypes, delimiter); |
157 | } |
158 | inline ParseResult parseDynamicIndexList( |
159 | OpAsmParser &parser, |
160 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
161 | DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes, |
162 | DenseBoolArrayAttr &scalableVals, |
163 | AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { |
164 | |
165 | return parseDynamicIndexList(parser, values, integers, scalableVals, |
166 | valueTypes: &valueTypes, delimiter); |
167 | } |
168 | |
169 | /// Verify that a the `values` has as many elements as the number of entries in |
170 | /// `attr` for which `isDynamic` evaluates to true. |
171 | LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, |
172 | unsigned expectedNumElements, |
173 | ArrayRef<int64_t> attr, |
174 | ValueRange values); |
175 | |
176 | } // namespace mlir |
177 | |
178 | #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ |
179 | |