1//===- ViewLikeInterface.h - View-like operations interface ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the operation interface for view-like operations.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
14#define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
15
16#include "mlir/Dialect/Utils/StaticValueUtils.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/OpImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22
23namespace mlir {
24
25class OffsetSizeAndStrideOpInterface;
26
27namespace detail {
28
29LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
30
31bool sameOffsetsSizesAndStrides(
32 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
33 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
34
35/// Helper method to compute the number of dynamic entries of `staticVals`,
36/// up to `idx`.
37unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
38 unsigned idx);
39
40} // namespace detail
41} // namespace mlir
42
43/// Include the generated interface declarations.
44#include "mlir/Interfaces/ViewLikeInterface.h.inc"
45
46namespace mlir {
47
/// Result of a slice bounds verification (see `verifyInBoundsSlice`).
struct SliceBoundsVerificationResult {
  /// "true" if the slice bounds verification was successful. Defaults to
  /// "false" so that a default-constructed result never reports success and
  /// never holds an indeterminate value.
  bool isValid = false;
  /// An error message that can be printed during op verification. May be
  /// empty; see the `generateErrorMessage` flag of `verifyInBoundsSlice`.
  std::string errorMessage;
};
55
56/// Verify that the offsets/sizes/strides-style access into the given shape
57/// is in-bounds. Only static values are verified. If `generateErrorMessage`
58/// is set to "true", an error message is produced that can be printed by the
59/// op verifier.
60SliceBoundsVerificationResult
61verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
62 ArrayRef<int64_t> staticSizes,
63 ArrayRef<int64_t> staticStrides,
64 bool generateErrorMessage = false);
65SliceBoundsVerificationResult verifyInBoundsSlice(
66 ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
67 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
68 bool generateErrorMessage = false);
69
70/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
71/// constant arguments. This pattern assumes that the op has a suitable builder
72/// that takes a result type, a "source" operand and mixed offsets, sizes and
73/// strides.
74///
75/// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn`
76/// returns the new result type of the op, based on the new offsets, sizes and
77/// strides. `CastOpFunc` is used to generate a cast op if the result type of
78/// the op has changed.
79template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
80class OpWithOffsetSizesAndStridesConstantArgumentFolder final
81 : public OpRewritePattern<OpType> {
82public:
83 using OpRewritePattern<OpType>::OpRewritePattern;
84
85 LogicalResult matchAndRewrite(OpType op,
86 PatternRewriter &rewriter) const override {
87 SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
88 SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
89 SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
90
91 // No constant operands were folded, just return;
92 if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
93 failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
94 failed(foldDynamicIndexList(mixedStrides)))
95 return failure();
96
97 // Pattern does not apply if the produced op would not verify.
98 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
99 cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
100 mixedSizes, mixedStrides);
101 if (!sliceResult.isValid)
102 return failure();
103
104 // Compute the new result type.
105 auto resultType =
106 ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
107 if (!resultType)
108 return failure();
109
110 // Create the new op in canonical form.
111 auto newOp =
112 rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
113 mixedOffsets, mixedSizes, mixedStrides);
114 CastOpFunc()(rewriter, op, newOp);
115
116 return success();
117 }
118};
119
120/// Printer hooks for custom directive in assemblyFormat.
121///
122/// custom<DynamicIndexList>($values, $integers)
123/// custom<DynamicIndexList>($values, $integers, type($values))
124///
125/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS type
126/// `I64ArrayAttr`. Print a list where each element is either:
127/// 1. the static integer value in `integers`, if it's not `kDynamic` or,
128/// 2. the next value in `values`, otherwise.
129///
130/// If `valueTypes` is provided, the corresponding type of each dynamic value is
131/// printed. Otherwise, the type is not printed. Each type must match the type
132/// of the corresponding value in `values`. `valueTypes` is redundant for
133/// printing as we can retrieve the types from the actual `values`. However,
134/// `valueTypes` is needed for parsing and we must keep the API symmetric for
135/// parsing and printing. The type for integer elements is `i64` by default and
136/// never printed.
137///
138/// Integer indices can also be scalable in the context of scalable vectors,
139/// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in
140/// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's
141/// a scalable index. If `scalableFlags` is empty then assume that all indices
142/// are non-scalable.
143///
144/// Examples:
145///
146/// * Input: `integers = [kDynamic, 7, 42, kDynamic]`,
147/// `values = [%arg0, %arg42]` and
148/// `valueTypes = [index, index]`
149/// prints:
150/// `[%arg0 : index, 7, 42, %arg42 : i32]`
151///
152/// * Input: `integers = [kDynamic, 7, 42, kDynamic]`,
153/// `values = [%arg0, %arg42]` and
154/// `valueTypes = []`
155/// prints:
156/// `[%arg0, 7, 42, %arg42]`
157///
158/// * Input: `integers = [2, 4, 8]`,
159/// `values = []` and
160/// `scalableFlags = [false, true, false]`
161/// prints:
162/// `[2, [4], 8]`
163///
164void printDynamicIndexList(
165 OpAsmPrinter &printer, Operation *op, OperandRange values,
166 ArrayRef<int64_t> integers, ArrayRef<bool> scalableFlags,
167 TypeRange valueTypes = TypeRange(),
168 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
169inline void printDynamicIndexList(
170 OpAsmPrinter &printer, Operation *op, OperandRange values,
171 ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
172 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
173 return printDynamicIndexList(printer, op, values, integers,
174 /*scalableFlags=*/{}, valueTypes, delimiter);
175}
176
177/// Parser hooks for custom directive in assemblyFormat.
178///
179/// custom<DynamicIndexList>($values, $integers)
180/// custom<DynamicIndexList>($values, $integers, type($values))
181///
182/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
183/// type `I64ArrayAttr`. Parse a mixed list where each element is either a
184/// static integer or an SSA value. Fill `integers` with the integer ArrayAttr,
185/// where `kDynamic` encodes the position of SSA values. Add the parsed SSA
186/// values to `values` in-order.
187///
188/// If `valueTypes` is provided, fill it with the types corresponding to each
189/// value in `values`. Otherwise, the caller must handle the types and parsing
190/// will fail if the type of the value is found (e.g., `[%arg0 : index, 3, %arg1
191/// : index]`).
192///
193/// Integer indices can also be scalable in the context of scalable vectors,
194/// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in
195/// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's
196/// a scalable index.
197///
198/// Examples:
199///
200/// * After parsing "[%arg0 : index, 7, 42, %arg42 : i32]":
201/// 1. `result` is filled with `[kDynamic, 7, 42, kDynamic]`
202/// 2. `values` is filled with "[%arg0, %arg1]".
203/// 3. `scalableFlags` is filled with `[false, true, false]`.
204///
205/// * After parsing `[2, [4], 8]`:
206/// 1. `result` is filled with `[2, 4, 8]`
207/// 2. `values` is empty.
208/// 3. `scalableFlags` is filled with `[false, true, false]`.
209///
210ParseResult parseDynamicIndexList(
211 OpAsmParser &parser,
212 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
213 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
214 SmallVectorImpl<Type> *valueTypes = nullptr,
215 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
216inline ParseResult parseDynamicIndexList(
217 OpAsmParser &parser,
218 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
219 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
220 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
221 DenseBoolArrayAttr scalableFlags;
222 return parseDynamicIndexList(parser, values, integers, scalableFlags,
223 valueTypes, delimiter);
224}
225
226/// Verify that a the `values` has as many elements as the number of entries in
227/// `attr` for which `isDynamic` evaluates to true.
228LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
229 unsigned expectedNumElements,
230 ArrayRef<int64_t> attr,
231 ValueRange values);
232
233} // namespace mlir
234
235#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
236

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/include/mlir/Interfaces/ViewLikeInterface.h