1//===- ViewLikeInterface.h - View-like operations interface ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the operation interface for view-like operations.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
14#define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
15
16#include "mlir/Dialect/Utils/StaticValueUtils.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/OpImplementation.h"
21#include "mlir/IR/PatternMatch.h"
22
23namespace mlir {
24
25class OffsetSizeAndStrideOpInterface;
26
27namespace detail {
28
29LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
30
31bool sameOffsetsSizesAndStrides(
32 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
33 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
34
35/// Helper method to compute the number of dynamic entries of `staticVals`,
36/// up to `idx`.
37unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
38 unsigned idx);
39
40} // namespace detail
41} // namespace mlir
42
43/// Include the generated interface declarations.
44#include "mlir/Interfaces/ViewLikeInterface.h.inc"
45
46namespace mlir {
47
48/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
49/// constant arguments. This pattern assumes that the op has a suitable builder
50/// that takes a result type, a "source" operand and mixed offsets, sizes and
51/// strides.
52///
53/// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn`
54/// returns the new result type of the op, based on the new offsets, sizes and
55/// strides. `CastOpFunc` is used to generate a cast op if the result type of
56/// the op has changed.
57template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
58class OpWithOffsetSizesAndStridesConstantArgumentFolder final
59 : public OpRewritePattern<OpType> {
60public:
61 using OpRewritePattern<OpType>::OpRewritePattern;
62
63 LogicalResult matchAndRewrite(OpType op,
64 PatternRewriter &rewriter) const override {
65 SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
66 SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
67 SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
68
69 // No constant operands were folded, just return;
70 if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
71 failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
72 failed(foldDynamicIndexList(mixedStrides)))
73 return failure();
74
75 // Create the new op in canonical form.
76 auto resultType =
77 ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
78 if (!resultType)
79 return failure();
80 auto newOp =
81 rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
82 mixedOffsets, mixedSizes, mixedStrides);
83 CastOpFunc()(rewriter, op, newOp);
84
85 return success();
86 }
87};
88
89/// Printer hook for custom directive in assemblyFormat.
90///
91/// custom<DynamicIndexList>($values, $integers)
92/// custom<DynamicIndexList>($values, $integers, type($values))
93///
94/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
95/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
96/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
97/// is non-empty, it is expected to contain as many elements as `values`
98/// indicating their types. This allows idiomatic printing of mixed value and
99/// integer attributes in a list. E.g.
100/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
101///
102/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
103/// This notation is similar to how scalable dims are marked when defining
104/// Vectors. For each value in `integers`, the corresponding `bool` in
105/// `scalables` encodes whether it's a scalable index. If `scalableVals` is
106/// empty then assume that all indices are non-scalable.
107void printDynamicIndexList(
108 OpAsmPrinter &printer, Operation *op, OperandRange values,
109 ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
110 ArrayRef<bool> scalables = {},
111 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
112
113/// Parser hook for custom directive in assemblyFormat.
114///
115/// custom<DynamicIndexList>($values, $integers)
116/// custom<DynamicIndexList>($values, $integers, type($values))
117///
118/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
119/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
120/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
121/// `kDynamic` encodes the position of SSA values. Add the parsed SSA values
122/// to `values` in-order. If `valueTypes` is non-null, fill it with types
123/// corresponding to values; otherwise the caller must handle the types.
124///
125/// E.g. after parsing "[%arg0 : index, 7, 42, %arg42 : i32]":
126/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
127/// `kDynamic`]"
128/// 2. `ssa` is filled with "[%arg0, %arg1]".
129///
130/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
131/// This notation is similar to how scalable dims are marked when defining
132/// Vectors. For each value in `integers`, the corresponding `bool` in
133/// `scalableVals` encodes whether it's a scalable index.
134ParseResult parseDynamicIndexList(
135 OpAsmParser &parser,
136 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
137 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
138 SmallVectorImpl<Type> *valueTypes = nullptr,
139 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
140inline ParseResult parseDynamicIndexList(
141 OpAsmParser &parser,
142 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
143 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
144 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
145 DenseBoolArrayAttr scalableVals = {};
146 return parseDynamicIndexList(parser, values, integers, scalableVals,
147 valueTypes, delimiter);
148}
149inline ParseResult parseDynamicIndexList(
150 OpAsmParser &parser,
151 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
152 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
153 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
154 DenseBoolArrayAttr scalableVals = {};
155 return parseDynamicIndexList(parser, values, integers, scalableVals,
156 &valueTypes, delimiter);
157}
158inline ParseResult parseDynamicIndexList(
159 OpAsmParser &parser,
160 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
161 DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
162 DenseBoolArrayAttr &scalableVals,
163 AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
164
165 return parseDynamicIndexList(parser, values, integers, scalableVals,
166 valueTypes: &valueTypes, delimiter);
167}
168
169/// Verify that a the `values` has as many elements as the number of entries in
170/// `attr` for which `isDynamic` evaluates to true.
171LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
172 unsigned expectedNumElements,
173 ArrayRef<int64_t> attr,
174 ValueRange values);
175
176} // namespace mlir
177
178#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
179

// source code of mlir/include/mlir/Interfaces/ViewLikeInterface.h