1//===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
10#define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
11
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/Dialect.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/OpImplementation.h"
18#include "mlir/Interfaces/CastInterfaces.h"
19#include "mlir/Interfaces/ControlFlowInterfaces.h"
20#include "mlir/Interfaces/DestinationStyleOpInterface.h"
21#include "mlir/Interfaces/InferTypeOpInterface.h"
22#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
23#include "mlir/Interfaces/ShapedOpInterfaces.h"
24#include "mlir/Interfaces/SideEffectInterfaces.h"
25#include "mlir/Interfaces/TilingInterface.h"
26#include "mlir/Interfaces/ViewLikeInterface.h"
27
28//===----------------------------------------------------------------------===//
29// Tensor Dialect Helpers
30//===----------------------------------------------------------------------===//
31
32namespace mlir {
33
34/// Return the list of Range (i.e. offset, size, stride). Each Range
35/// entry contains either the dynamic value or a ConstantIndexOp constructed
36/// with `b` at location `loc`.
37SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
38 OpBuilder &b, Location loc);
39
40} // namespace mlir
41
42//===----------------------------------------------------------------------===//
43// Tensor Dialect
44//===----------------------------------------------------------------------===//
45
46#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
47
48//===----------------------------------------------------------------------===//
49// Tensor Dialect Operations
50//===----------------------------------------------------------------------===//
51
52#define GET_OP_CLASSES
53#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
54
55//===----------------------------------------------------------------------===//
56// Tensor Dialect Helpers
57//===----------------------------------------------------------------------===//
58
59namespace mlir {
60namespace tensor {
61
62/// Returns true if `target` is a ranked tensor type that preserves static
63/// information available in the `source` ranked tensor type.
64bool preservesStaticInformation(Type source, Type target);
65
66/// Determines whether tensor::CastOp casts to a more dynamic version of the
67/// source tensor. This is useful to fold a tensor.cast into a consuming op and
68/// implement canonicalization patterns for ops in different dialects that may
69/// consume the results of tensor.cast operations. Such foldable tensor.cast
70/// operations are typically inserted as `extract_slice` ops and are
71/// canonicalized, to preserve the type compatibility of their uses.
72///
73/// Returns true when all conditions are met:
74/// 1. source and result are ranked tensors with same element type and rank.
75/// 2. the tensor type has more static information than the result
76///
77/// Example:
78/// ```mlir
79/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
80/// %2 = consumer %1 ... : tensor<?x?xf32> ...
81/// ```
82///
83/// folds into:
84///
85/// ```mlir
86/// %2 = consumer %0 ... : tensor<8x16xf32> ...
87/// ```
88bool canFoldIntoConsumerOp(CastOp castOp);
89
90/// Determines whether the tensor::CastOp casts to a more static version of the
91/// source tensor. This is useful to fold into a producing op and implement
92/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
93/// being from different dialects. Returns true when all conditions are met:
94/// 1. source and result and ranked tensors with same element type and rank.
95/// 2. the result type has more static information than the source.
96///
97/// Example:
98/// ```mlir
99/// %1 = producer ... : tensor<?x?xf32>
100/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
101/// ```
102///
103/// can be canonicalized to :
104///
105/// ```mlir
106/// %2 = producer ... : tensor<8x16xf32>
107/// ```
108/// Not all ops might be canonicalizable this way, but for those that can be,
109/// this method provides a check that it is worth doing the canonicalization.
110bool canFoldIntoProducerOp(CastOp castOp);
111
112/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
113/// that can be folded.
114LogicalResult foldTensorCast(Operation *op);
115
116/// Return the dimension of the given tensor value.
117OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
118 int64_t dim);
119
120/// Return the dimensions of the given tensor value.
121SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
122 Value value);
123
124/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
125/// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
126/// to that of `targetType`.
127Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
128 Value tensor,
129 RankedTensorType targetType);
130
131/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
132/// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
133/// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
134/// at the canonical [0 .. 0] position.
135Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
136 Value tensor, Value dest);
137
138/// This is a helper function for DestinationStyleOpInterface. If there is a
139/// destination operand for the given OpResult, return that operand. Otherwise,
140/// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
141/// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
142FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
143 OpResult opResult);
144
145/// This is a helper function for DestinationStyleOpInterface. Get or create
146/// destinations for every tensor OpResult of the given op.
147LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
148 SmallVector<Value> &result);
149
150/// Tests if types are the same when ignoring encoding on ranked tensors.
151bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
152
153/// Function to control the folding of constant and extract slice.
154using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
155
156/// Patterns to fold the extract slice op with its constant operand.
157void populateFoldConstantExtractSlicePatterns(
158 RewritePatternSet &patterns,
159 const ControlConstantExtractSliceFusionFn &controlFn =
160 [](ExtractSliceOp op) {
161 // Disable by default because the folding can generate a large
162 // constant tensor, which would affect the compile time and storage.
163 return false;
164 });
165
166} // namespace tensor
167} // namespace mlir
168
169#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
170

source code of mlir/include/mlir/Dialect/Tensor/IR/Tensor.h