1 | //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the Vector dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
14 | #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
15 | |
16 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
18 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
19 | #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" |
20 | #include "mlir/IR/AffineMap.h" |
21 | #include "mlir/IR/Attributes.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/Dialect.h" |
24 | #include "mlir/IR/OpDefinition.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
28 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
29 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
30 | #include "mlir/Interfaces/VectorInterfaces.h" |
31 | #include "mlir/Interfaces/ViewLikeInterface.h" |
32 | #include "llvm/ADT/SetVector.h" |
33 | #include "llvm/ADT/StringExtras.h" |
34 | |
35 | // Pull in all enum type definitions and utility function declarations. |
36 | #include "mlir/Dialect/Vector/IR/VectorEnums.h.inc" |
37 | |
38 | #define GET_ATTRDEF_CLASSES |
39 | #include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc" |
40 | |
41 | namespace mlir { |
42 | class MLIRContext; |
43 | class RewritePatternSet; |
44 | |
45 | namespace arith { |
46 | enum class AtomicRMWKind : uint64_t; |
47 | } // namespace arith |
48 | |
49 | namespace vector { |
50 | class ContractionOp; |
51 | class TransferReadOp; |
52 | class TransferWriteOp; |
53 | class VectorDialect; |
54 | |
55 | namespace detail { |
56 | struct BitmaskEnumStorage; |
57 | } // namespace detail |
58 | |
59 | /// Default callback to build a region with a 'vector.yield' terminator with no |
60 | /// arguments. |
61 | void buildTerminatedBody(OpBuilder &builder, Location loc); |
62 | |
63 | /// Return whether `srcType` can be broadcast to `dstVectorType` under the |
64 | /// semantics of the `vector.broadcast` op. |
65 | enum class BroadcastableToResult { |
66 | Success = 0, |
67 | SourceRankHigher = 1, |
68 | DimensionMismatch = 2, |
69 | SourceTypeNotAVector = 3 |
70 | }; |
71 | BroadcastableToResult |
72 | isBroadcastableTo(Type srcType, VectorType dstVectorType, |
73 | std::pair<int, int> *mismatchingDims = nullptr); |
74 | |
75 | /// Collect a set of vector-to-vector canonicalization patterns. |
76 | void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, |
77 | PatternBenefit benefit = 1); |
78 | |
79 | /// Collect a set of patterns that fold arithmetic extension on floating point |
80 | /// into vector contract for the backends with native support. |
81 | void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); |
82 | |
83 | /// Returns the integer type required for subscripts in the vector dialect. |
84 | IntegerType getVectorSubscriptType(Builder &builder); |
85 | |
86 | /// Returns an integer array attribute containing the given values using |
87 | /// the integer type required for subscripts in the vector dialect. |
88 | ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values); |
89 | |
90 | /// Returns the value obtained by reducing the vector into a scalar using the |
91 | /// operation kind associated with a binary AtomicRMWKind op. |
92 | Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, |
93 | Location loc, Value vector); |
94 | |
95 | /// Build the default minor identity map suitable for a vector transfer. This |
96 | /// also handles the case memref<... x vector<...>> -> vector<...> in which the |
97 | /// rank of the identity map must take the vector element type into account. |
98 | AffineMap getTransferMinorIdentityMap(ShapedType shapedType, |
99 | VectorType vectorType); |
100 | |
101 | /// Return true if the transfer_write fully writes the data accessed by the |
102 | /// transfer_read. |
103 | bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read); |
104 | |
105 | /// Return true if the write op fully over-write the priorWrite transfer_write |
106 | /// op. |
107 | bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite); |
108 | |
109 | /// Return true if we can prove that the transfer operations access disjoint |
110 | /// memory, without requring the accessed tensor/memref to be the same. |
111 | /// |
112 | /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values |
113 | /// via ValueBoundsOpInterface. |
114 | bool isDisjointTransferIndices(VectorTransferOpInterface transferA, |
115 | VectorTransferOpInterface transferB, |
116 | bool testDynamicValueUsingBounds = false); |
117 | |
118 | /// Return true if we can prove that the transfer operations access disjoint |
119 | /// memory, requiring the operations to access the same tensor/memref. |
120 | /// |
121 | /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values |
122 | /// via ValueBoundsOpInterface. |
123 | bool isDisjointTransferSet(VectorTransferOpInterface transferA, |
124 | VectorTransferOpInterface transferB, |
125 | bool testDynamicValueUsingBounds = false); |
126 | |
127 | /// Returns the result value of reducing two scalar/vector values with the |
128 | /// corresponding arith operation. |
129 | Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, |
130 | Value v1, Value acc, |
131 | arith::FastMathFlagsAttr fastmath = nullptr, |
132 | Value mask = nullptr); |
133 | |
134 | /// Returns true if `attr` has "parallel" iterator type semantics. |
135 | inline bool isParallelIterator(Attribute attr) { |
136 | return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel; |
137 | } |
138 | |
139 | /// Returns true if `attr` has "reduction" iterator type semantics. |
140 | inline bool isReductionIterator(Attribute attr) { |
141 | return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction; |
142 | } |
143 | |
144 | /// Returns the integer numbers in `values`. `values` are expected to be |
145 | /// constant operations. |
146 | SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values); |
147 | |
148 | /// Returns the integer numbers in `foldResults`. `foldResults` are expected to |
149 | /// be constant operations. |
150 | SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults); |
151 | |
152 | /// Convert `foldResults` into Values. Integer attributes are converted to |
153 | /// constant op. |
154 | SmallVector<Value> getAsValues(OpBuilder &builder, Location loc, |
155 | ArrayRef<OpFoldResult> foldResults); |
156 | |
157 | /// Returns the constant index ops in `values`. `values` are expected to be |
158 | /// constant operations. |
159 | SmallVector<arith::ConstantIndexOp> |
160 | getAsConstantIndexOps(ArrayRef<Value> values); |
161 | |
162 | //===----------------------------------------------------------------------===// |
163 | // Vector Masking Utilities |
164 | //===----------------------------------------------------------------------===// |
165 | |
166 | /// Infers the mask type for a transfer op given its vector type and |
167 | /// permutation map. The mask in a transfer op operation applies to the |
168 | /// tensor/buffer part of it and its type should match the vector shape |
169 | /// *before* any permutation or broadcasting. For example, |
170 | /// |
171 | /// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)> |
172 | /// |
173 | /// Has inferred mask type: |
174 | /// |
175 | /// maskType = vector<2x1xi1> |
176 | VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap); |
177 | |
178 | /// Create the vector.yield-ended region of a vector.mask op with `maskableOp` |
179 | /// as masked operation. |
180 | void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp); |
181 | |
182 | /// Creates a vector.mask operation around a maskable operation. Returns the |
183 | /// vector.mask operation if the mask provided is valid. Otherwise, returns the |
184 | /// maskable operation itself. |
185 | Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, |
186 | Value passthru = Value()); |
187 | |
188 | /// Creates a vector select operation that picks values from `newValue` or |
189 | /// `passthru` for each result vector lane based on `mask`. This utility is used |
190 | /// to propagate the pass-thru value for masked-out or expeculatively executed |
191 | /// lanes. VP intrinsics do not support pass-thru values and every mask-out lane |
192 | /// is set to poison. LLVM backends are usually able to match op + select |
193 | /// patterns and fold them into a native target instructions. |
194 | Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, |
195 | Value passthru); |
196 | |
197 | } // namespace vector |
198 | } // namespace mlir |
199 | |
200 | #define GET_OP_CLASSES |
201 | #include "mlir/Dialect/Vector/IR/VectorDialect.h.inc" |
202 | #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" |
203 | |
204 | #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
205 | |