1 | //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the Vector dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
14 | #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
15 | |
16 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
18 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
19 | #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" |
20 | #include "mlir/IR/AffineMap.h" |
21 | #include "mlir/IR/Attributes.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/Dialect.h" |
24 | #include "mlir/IR/OpDefinition.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
28 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
29 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
30 | #include "mlir/Interfaces/VectorInterfaces.h" |
31 | #include "mlir/Interfaces/ViewLikeInterface.h" |
32 | #include "llvm/ADT/SetVector.h" |
33 | #include "llvm/ADT/StringExtras.h" |
34 | |
35 | // Pull in all enum type definitions and utility function declarations. |
36 | #include "mlir/Dialect/Vector/IR/VectorEnums.h.inc" |
37 | |
38 | #define GET_ATTRDEF_CLASSES |
39 | #include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc" |
40 | |
41 | namespace mlir { |
42 | class MLIRContext; |
43 | class RewritePatternSet; |
44 | |
45 | namespace arith { |
46 | enum class AtomicRMWKind : uint64_t; |
47 | } // namespace arith |
48 | |
49 | namespace vector { |
50 | class ContractionOp; |
51 | class TransferReadOp; |
52 | class TransferWriteOp; |
53 | class VectorDialect; |
54 | |
55 | namespace detail { |
56 | struct BitmaskEnumStorage; |
57 | } // namespace detail |
58 | |
59 | /// Predefined constant_mask kinds. |
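/// For example (an illustrative correspondence, not normative), for a
/// `vector<4xi1>` mask, `AllFalse` matches `vector.constant_mask [0] :
/// vector<4xi1>` and `AllTrue` matches `vector.constant_mask [4] :
/// vector<4xi1>`.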
60 | enum class ConstantMaskKind { AllFalse = 0, AllTrue }; |
61 | |
62 | /// Default callback to build a region with a 'vector.yield' terminator with no |
63 | /// arguments. |
64 | void buildTerminatedBody(OpBuilder &builder, Location loc); |
65 | |
/// The result of checking whether a source type can be broadcast to a
/// destination vector type (see `isBroadcastableTo`).
enum class BroadcastableToResult {
  Success = 0,
  SourceRankHigher = 1,
  DimensionMismatch = 2,
  SourceTypeNotAVector = 3
};
74 | |
/// A vector dimension together with its scalability.
struct VectorDim {
  int64_t dim;
  bool isScalable;
};
/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op. On `DimensionMismatch`, the
/// mismatching pair of source/destination dimensions is reported through
/// `mismatchingDims` when it is non-null.
BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
                  std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
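
// Illustrative expected results for `isBroadcastableTo` (a sketch following
// `vector.broadcast` semantics, not normative):
//
//   isBroadcastableTo(f32, vector<4xf32>)             -> Success
//   isBroadcastableTo(vector<1xf32>, vector<4xf32>)   -> Success (stretch)
//   isBroadcastableTo(vector<1x4xf32>, vector<4xf32>) -> SourceRankHigher
//   isBroadcastableTo(vector<8xf32>, vector<4xf32>)   -> DimensionMismatch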
82 | |
83 | /// Collect a set of vector-to-vector canonicalization patterns. |
84 | void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, |
85 | PatternBenefit benefit = 1); |
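//
// A typical usage sketch inside a pass (illustrative; assumes the greedy
// rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h, whose
// entry point is `applyPatternsGreedily`, or `applyPatternsAndFoldGreedily`
// in older MLIR):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorToVectorCanonicalizationPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));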
86 | |
/// Collect a set of patterns that fold arithmetic extensions on floating-point
/// values into vector contract ops, for backends with native support.
89 | void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); |
90 | |
/// Collect a set of patterns that fold elementwise ops on vectors into the
/// vector dialect.
93 | void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns); |
94 | |
95 | /// Returns the integer type required for subscripts in the vector dialect. |
96 | IntegerType getVectorSubscriptType(Builder &builder); |
97 | |
98 | /// Returns an integer array attribute containing the given values using |
99 | /// the integer type required for subscripts in the vector dialect. |
100 | ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values); |
101 | |
102 | /// Returns the value obtained by reducing the vector into a scalar using the |
103 | /// operation kind associated with a binary AtomicRMWKind op. |
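///
/// For example (illustrative), `arith::AtomicRMWKind::addf` over a
/// `vector<4xf32>` builds roughly:
///
///   %0 = vector.reduction <add>, %vector : vector<4xf32> into f32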
104 | Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, |
105 | Location loc, Value vector); |
106 | |
107 | /// Build the default minor identity map suitable for a vector transfer. This |
108 | /// also handles the case memref<... x vector<...>> -> vector<...> in which the |
109 | /// rank of the identity map must take the vector element type into account. |
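///
/// For example (illustrative):
///
///   (memref<?x?xf32>, vector<4xf32>)           -> affine_map<(d0, d1) -> (d1)>
///   (memref<?xvector<4xf32>>, vector<1x4xf32>) -> affine_map<(d0) -> (d0)>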
110 | AffineMap getTransferMinorIdentityMap(ShapedType shapedType, |
111 | VectorType vectorType); |
112 | |
/// Return true if the transfer_write `defWrite` fully writes the data accessed
/// by the transfer_read `read`.
115 | bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read); |
116 | |
/// Return true if the transfer_write `write` fully overwrites the data written
/// by the prior transfer_write `priorWrite`.
119 | bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite); |
120 | |
/// Return true if we can prove that the transfer operations access disjoint
/// memory, without requiring the accessed tensor/memref to be the same.
123 | /// |
124 | /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values |
125 | /// via ValueBoundsOpInterface. |
126 | bool isDisjointTransferIndices(VectorTransferOpInterface transferA, |
127 | VectorTransferOpInterface transferB, |
128 | bool testDynamicValueUsingBounds = false); |
129 | |
130 | /// Return true if we can prove that the transfer operations access disjoint |
131 | /// memory, requiring the operations to access the same tensor/memref. |
132 | /// |
133 | /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values |
134 | /// via ValueBoundsOpInterface. |
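///
/// For example (illustrative), the following two transfers access disjoint
/// ranges ([0, 4) and [4, 8)) of the same memref:
///
///   %0 = vector.transfer_read %mem[%c0], %pad : memref<8xf32>, vector<4xf32>
///   vector.transfer_write %v, %mem[%c4] : vector<4xf32>, memref<8xf32>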
135 | bool isDisjointTransferSet(VectorTransferOpInterface transferA, |
136 | VectorTransferOpInterface transferB, |
137 | bool testDynamicValueUsingBounds = false); |
138 | |
139 | /// Returns the result value of reducing two scalar/vector values with the |
140 | /// corresponding arith operation. |
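///
/// For example (illustrative), `CombiningKind::ADD` on f32 operands builds
/// roughly:
///
///   %0 = arith.addf %v1, %acc : f32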
141 | Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, |
142 | Value v1, Value acc, |
143 | arith::FastMathFlagsAttr fastmath = nullptr, |
144 | Value mask = nullptr); |
145 | |
146 | /// Returns true if `attr` has "parallel" iterator type semantics. |
147 | inline bool isParallelIterator(Attribute attr) { |
148 | return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel; |
149 | } |
150 | |
151 | /// Returns true if `attr` has "reduction" iterator type semantics. |
152 | inline bool isReductionIterator(Attribute attr) { |
153 | return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction; |
154 | } |
155 | |
/// Returns the integer numbers in `values`, which are expected to be defined
/// by constant operations.
158 | SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values); |
159 | |
/// Returns the integer numbers in `foldResults`, which are expected to hold
/// constant values.
162 | SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults); |
163 | |
/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
166 | SmallVector<Value> getAsValues(OpBuilder &builder, Location loc, |
167 | ArrayRef<OpFoldResult> foldResults); |
168 | |
169 | /// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst * |
170 | /// vector.vscale`), return the multiplier (`%cst`). Otherwise, return |
171 | /// `std::nullopt`. |
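///
/// For example (illustrative):
///
///   %c8 = arith.constant 8 : index
///   %vs = vector.vscale
///   %n  = arith.muli %c8, %vs : index  // getConstantVscaleMultiplier(%n) == 8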
172 | std::optional<int64_t> getConstantVscaleMultiplier(Value value); |
173 | |
174 | //===----------------------------------------------------------------------===// |
175 | // Vector Masking Utilities |
176 | //===----------------------------------------------------------------------===// |
177 | |
/// Infers the mask type for a transfer op given its vector type and
/// permutation map. The mask in a transfer operation applies to the
/// tensor/buffer part of it and its type should match the vector shape
/// *before* any permutation or broadcasting. For example,
182 | /// |
183 | /// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)> |
184 | /// |
185 | /// Has inferred mask type: |
186 | /// |
187 | /// maskType = vector<2x1xi1> |
188 | VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap); |
189 | |
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as the masked operation.
192 | void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp); |
193 | |
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the provided mask is valid; otherwise, returns the
/// maskable operation itself.
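///
/// For example (illustrative), masking a transfer_read builds roughly:
///
///   %0 = vector.mask %mask { vector.transfer_read %mem[%i], %pad
///          : memref<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>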
197 | Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, |
198 | Value passthru = Value()); |
199 | |
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used
/// to propagate the pass-thru value for masked-out or speculatively executed
/// lanes. VP intrinsics do not support pass-thru values and every masked-out
/// lane is set to poison. LLVM backends are usually able to match op + select
/// patterns and fold them into native target instructions.
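///
/// For example (illustrative):
///
///   %0 = arith.select %mask, %newValue, %passthru
///          : vector<4xi1>, vector<4xf32>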
206 | Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, |
207 | Value passthru); |
208 | |
209 | } // namespace vector |
210 | } // namespace mlir |
211 | |
212 | #define GET_OP_CLASSES |
213 | #include "mlir/Dialect/Vector/IR/VectorDialect.h.inc" |
214 | #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" |
215 | |
216 | #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
217 | |