1 | //===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_ |
10 | #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ |
11 | |
12 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
13 | #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
14 | #include "mlir/IR/BuiltinTypes.h" |
15 | #include "mlir/IR/Dialect.h" |
16 | #include "mlir/IR/OpDefinition.h" |
17 | #include "mlir/IR/OpImplementation.h" |
18 | #include "mlir/Interfaces/CastInterfaces.h" |
19 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
20 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
21 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
22 | #include "mlir/Interfaces/ParallelCombiningOpInterface.h" |
23 | #include "mlir/Interfaces/ShapedOpInterfaces.h" |
24 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
25 | #include "mlir/Interfaces/TilingInterface.h" |
26 | #include "mlir/Interfaces/ViewLikeInterface.h" |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // Tensor Dialect Helpers |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | namespace mlir { |
33 | |
34 | /// Return the list of Range (i.e. offset, size, stride). Each Range |
35 | /// entry contains either the dynamic value or a ConstantIndexOp constructed |
36 | /// with `b` at location `loc`. |
37 | SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op, |
38 | OpBuilder &b, Location loc); |
39 | |
40 | } // namespace mlir |
41 | |
42 | //===----------------------------------------------------------------------===// |
43 | // Tensor Dialect |
44 | //===----------------------------------------------------------------------===// |
45 | |
46 | #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc" |
47 | |
48 | //===----------------------------------------------------------------------===// |
49 | // Tensor Dialect Operations |
50 | //===----------------------------------------------------------------------===// |
51 | |
52 | #define GET_OP_CLASSES |
53 | #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc" |
54 | |
55 | //===----------------------------------------------------------------------===// |
56 | // Tensor Dialect Helpers |
57 | //===----------------------------------------------------------------------===// |
58 | |
59 | namespace mlir { |
60 | namespace tensor { |
61 | |
62 | /// Returns true if `target` is a ranked tensor type that preserves static |
63 | /// information available in the `source` ranked tensor type. |
64 | bool preservesStaticInformation(Type source, Type target); |
65 | |
66 | /// Determines whether tensor::CastOp casts to a more dynamic version of the |
67 | /// source tensor. This is useful to fold a tensor.cast into a consuming op and |
68 | /// implement canonicalization patterns for ops in different dialects that may |
69 | /// consume the results of tensor.cast operations. Such foldable tensor.cast |
70 | /// operations are typically inserted as `extract_slice` ops and are |
71 | /// canonicalized, to preserve the type compatibility of their uses. |
72 | /// |
73 | /// Returns true when all conditions are met: |
74 | /// 1. source and result are ranked tensors with same element type and rank. |
75 | /// 2. the tensor type has more static information than the result |
76 | /// |
77 | /// Example: |
78 | /// ```mlir |
79 | /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> |
80 | /// %2 = consumer %1 ... : tensor<?x?xf32> ... |
81 | /// ``` |
82 | /// |
83 | /// folds into: |
84 | /// |
85 | /// ```mlir |
86 | /// %2 = consumer %0 ... : tensor<8x16xf32> ... |
87 | /// ``` |
88 | bool canFoldIntoConsumerOp(CastOp castOp); |
89 | |
90 | /// Determines whether the tensor::CastOp casts to a more static version of the |
91 | /// source tensor. This is useful to fold into a producing op and implement |
92 | /// canonicaliation patterns with the `tensor.cast` op as the root, but producer |
93 | /// being from different dialects. Returns true when all conditions are met: |
94 | /// 1. source and result and ranked tensors with same element type and rank. |
95 | /// 2. the result type has more static information than the source. |
96 | /// |
97 | /// Example: |
98 | /// ```mlir |
99 | /// %1 = producer ... : tensor<?x?xf32> |
100 | /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32> |
101 | /// ``` |
102 | /// |
103 | /// can be canonicalized to : |
104 | /// |
105 | /// ```mlir |
106 | /// %2 = producer ... : tensor<8x16xf32> |
107 | /// ``` |
108 | /// Not all ops might be canonicalizable this way, but for those that can be, |
109 | /// this method provides a check that it is worth doing the canonicalization. |
110 | bool canFoldIntoProducerOp(CastOp castOp); |
111 | |
112 | /// Performs folding of any operand of `op` if it comes from a tensor::CastOp |
113 | /// that can be folded. |
114 | LogicalResult foldTensorCast(Operation *op); |
115 | |
116 | /// Return the dimension of the given tensor value. |
117 | OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, |
118 | int64_t dim); |
119 | |
120 | /// Return the dimensions of the given tensor value. |
121 | SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc, |
122 | Value value); |
123 | |
124 | /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and |
125 | /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor` |
126 | /// to that of `targetType`. |
127 | Value (OpBuilder &b, Location loc, |
128 | Value tensor, |
129 | RankedTensorType targetType); |
130 | |
131 | /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and |
132 | /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with |
133 | /// rank increased to that of `dest`, obtained by inserting `tensor` into `dest` |
134 | /// at the canonical [0 .. 0] position. |
135 | Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, |
136 | Value tensor, Value dest); |
137 | |
138 | /// This is a helper function for DestinationStyleOpInterface. If there is a |
139 | /// destination operand for the given OpResult, return that operand. Otherwise, |
140 | /// return an empty tensor (`tensor.empty`) with the shape of the OpResult. |
141 | /// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface. |
142 | FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc, |
143 | OpResult opResult); |
144 | |
145 | /// This is a helper function for DestinationStyleOpInterface. Get or create |
146 | /// destinations for every tensor OpResult of the given op. |
147 | LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, |
148 | SmallVector<Value> &result); |
149 | |
150 | /// Tests if types are the same when ignoring encoding on ranked tensors. |
151 | bool isSameTypeWithoutEncoding(Type tp1, Type tp2); |
152 | |
153 | /// Function to control the folding of constant and extract slice. |
154 | using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>; |
155 | |
156 | /// Patterns to fold the extract slice op with its constant operand. |
157 | void ( |
158 | RewritePatternSet &patterns, |
159 | const ControlConstantExtractSliceFusionFn &controlFn = |
160 | [](ExtractSliceOp op) { |
161 | // Disable by default because the folding can generate a large |
162 | // constant tensor, which would affect the compile time and storage. |
163 | return false; |
164 | }); |
165 | |
166 | } // namespace tensor |
167 | } // namespace mlir |
168 | |
169 | #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_ |
170 | |