| 1 | //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the Vector dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
| 14 | #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
| 15 | |
| 16 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
| 17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 18 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
| 19 | #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" |
| 20 | #include "mlir/IR/AffineMap.h" |
| 21 | #include "mlir/IR/Attributes.h" |
| 22 | #include "mlir/IR/BuiltinTypes.h" |
| 23 | #include "mlir/IR/Dialect.h" |
| 24 | #include "mlir/IR/OpDefinition.h" |
| 25 | #include "mlir/IR/PatternMatch.h" |
| 26 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| 27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
| 28 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
| 29 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 30 | #include "mlir/Interfaces/VectorInterfaces.h" |
| 31 | #include "mlir/Interfaces/ViewLikeInterface.h" |
| 32 | #include "llvm/ADT/SetVector.h" |
| 33 | #include "llvm/ADT/StringExtras.h" |
| 34 | |
| 35 | // Pull in all enum type definitions and utility function declarations. |
| 36 | #include "mlir/Dialect/Vector/IR/VectorEnums.h.inc" |
| 37 | |
| 38 | #define GET_ATTRDEF_CLASSES |
| 39 | #include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc" |
| 40 | |
| 41 | namespace mlir { |
// Forward declarations of core MLIR classes referenced by name in this header.
class RewritePatternSet;
class MLIRContext;

// Opaque declaration of the arith dialect's atomic read-modify-write kind,
// carrying its fixed `uint64_t` underlying type.
namespace arith {
enum class AtomicRMWKind : uint64_t;
} // namespace arith
| 48 | |
| 49 | namespace vector { |
// Forward declarations of Vector dialect classes used by the declarations
// below (their definitions are expected to come from the generated headers
// included at the end of this file).
class VectorDialect;
class ContractionOp;
class TransferReadOp;
class TransferWriteOp;

namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Predefined `vector.constant_mask` kinds: `AllFalse` and `AllTrue`.
enum class ConstantMaskKind {
  AllFalse = 0,
  AllTrue = 1,
};
| 61 | |
| 62 | /// Default callback to build a region with a 'vector.yield' terminator with no |
| 63 | /// arguments. |
| 64 | void buildTerminatedBody(OpBuilder &builder, Location loc); |
| 65 | |
| 66 | /// Return whether `srcType` can be broadcast to `dstVectorType` under the |
| 67 | /// semantics of the `vector.broadcast` op. |
| 68 | enum class BroadcastableToResult { |
| 69 | Success = 0, |
| 70 | SourceRankHigher = 1, |
| 71 | DimensionMismatch = 2, |
| 72 | SourceTypeNotAVector = 3 |
| 73 | }; |
| 74 | |
| 75 | struct VectorDim { |
| 76 | int64_t dim; |
| 77 | bool isScalable; |
| 78 | }; |
| 79 | BroadcastableToResult |
| 80 | isBroadcastableTo(Type srcType, VectorType dstVectorType, |
| 81 | std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr); |
| 82 | |
| 83 | /// Collect a set of vector-to-vector canonicalization patterns. |
| 84 | void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, |
| 85 | PatternBenefit benefit = 1); |
| 86 | |
| 87 | /// Collect a set of patterns that fold arithmetic extension on floating point |
| 88 | /// into vector contract for the backends with native support. |
| 89 | void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); |
| 90 | |
| 91 | /// Collect a set of patterns that fold elementwise op on vectors to the vector |
| 92 | /// dialect. |
| 93 | void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns); |
| 94 | |
| 95 | /// Returns the integer type required for subscripts in the vector dialect. |
| 96 | IntegerType getVectorSubscriptType(Builder &builder); |
| 97 | |
| 98 | /// Returns an integer array attribute containing the given values using |
| 99 | /// the integer type required for subscripts in the vector dialect. |
| 100 | ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values); |
| 101 | |
| 102 | /// Returns the value obtained by reducing the vector into a scalar using the |
| 103 | /// operation kind associated with a binary AtomicRMWKind op. |
| 104 | Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, |
| 105 | Location loc, Value vector); |
| 106 | |
| 107 | /// Build the default minor identity map suitable for a vector transfer. This |
| 108 | /// also handles the case memref<... x vector<...>> -> vector<...> in which the |
| 109 | /// rank of the identity map must take the vector element type into account. |
| 110 | AffineMap getTransferMinorIdentityMap(ShapedType shapedType, |
| 111 | VectorType vectorType); |
| 112 | |
| 113 | /// Return true if the transfer_write fully writes the data accessed by the |
| 114 | /// transfer_read. |
| 115 | bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read); |
| 116 | |
| 117 | /// Return true if the write op fully over-write the priorWrite transfer_write |
| 118 | /// op. |
| 119 | bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite); |
| 120 | |
| 121 | /// Return true if we can prove that the transfer operations access disjoint |
| 122 | /// memory, without requring the accessed tensor/memref to be the same. |
| 123 | /// |
| 124 | /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values |
| 125 | /// via ValueBoundsOpInterface. |
| 126 | bool isDisjointTransferIndices(VectorTransferOpInterface transferA, |
| 127 | VectorTransferOpInterface transferB, |
| 128 | bool testDynamicValueUsingBounds = false); |
| 129 | |
| 130 | /// Return true if we can prove that the transfer operations access disjoint |
| 131 | /// memory, requiring the operations to access the same tensor/memref. |
| 132 | /// |
| 133 | /// If `testDynamicValueUsingBounds` is true, tries to test dynamic values |
| 134 | /// via ValueBoundsOpInterface. |
| 135 | bool isDisjointTransferSet(VectorTransferOpInterface transferA, |
| 136 | VectorTransferOpInterface transferB, |
| 137 | bool testDynamicValueUsingBounds = false); |
| 138 | |
| 139 | /// Returns the result value of reducing two scalar/vector values with the |
| 140 | /// corresponding arith operation. |
| 141 | Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, |
| 142 | Value v1, Value acc, |
| 143 | arith::FastMathFlagsAttr fastmath = nullptr, |
| 144 | Value mask = nullptr); |
| 145 | |
| 146 | /// Returns true if `attr` has "parallel" iterator type semantics. |
| 147 | inline bool isParallelIterator(Attribute attr) { |
| 148 | return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel; |
| 149 | } |
| 150 | |
| 151 | /// Returns true if `attr` has "reduction" iterator type semantics. |
| 152 | inline bool isReductionIterator(Attribute attr) { |
| 153 | return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction; |
| 154 | } |
| 155 | |
| 156 | /// Returns the integer numbers in `values`. `values` are expected to be |
| 157 | /// constant operations. |
| 158 | SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values); |
| 159 | |
| 160 | /// Returns the integer numbers in `foldResults`. `foldResults` are expected to |
| 161 | /// be constant operations. |
| 162 | SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults); |
| 163 | |
| 164 | /// Convert `foldResults` into Values. Integer attributes are converted to |
| 165 | /// constant op. |
| 166 | SmallVector<Value> getAsValues(OpBuilder &builder, Location loc, |
| 167 | ArrayRef<OpFoldResult> foldResults); |
| 168 | |
| 169 | /// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst * |
| 170 | /// vector.vscale`), return the multiplier (`%cst`). Otherwise, return |
| 171 | /// `std::nullopt`. |
| 172 | std::optional<int64_t> getConstantVscaleMultiplier(Value value); |
| 173 | |
| 174 | //===----------------------------------------------------------------------===// |
| 175 | // Vector Masking Utilities |
| 176 | //===----------------------------------------------------------------------===// |
| 177 | |
| 178 | /// Infers the mask type for a transfer op given its vector type and |
| 179 | /// permutation map. The mask in a transfer op operation applies to the |
| 180 | /// tensor/buffer part of it and its type should match the vector shape |
| 181 | /// *before* any permutation or broadcasting. For example, |
| 182 | /// |
| 183 | /// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)> |
| 184 | /// |
| 185 | /// Has inferred mask type: |
| 186 | /// |
| 187 | /// maskType = vector<2x1xi1> |
| 188 | VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap); |
| 189 | |
| 190 | /// Create the vector.yield-ended region of a vector.mask op with `maskableOp` |
| 191 | /// as masked operation. |
| 192 | void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp); |
| 193 | |
| 194 | /// Creates a vector.mask operation around a maskable operation. Returns the |
| 195 | /// vector.mask operation if the mask provided is valid. Otherwise, returns the |
| 196 | /// maskable operation itself. |
| 197 | Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, |
| 198 | Value passthru = Value()); |
| 199 | |
| 200 | /// Creates a vector select operation that picks values from `newValue` or |
| 201 | /// `passthru` for each result vector lane based on `mask`. This utility is used |
| 202 | /// to propagate the pass-thru value for masked-out or expeculatively executed |
| 203 | /// lanes. VP intrinsics do not support pass-thru values and every mask-out lane |
| 204 | /// is set to poison. LLVM backends are usually able to match op + select |
| 205 | /// patterns and fold them into a native target instructions. |
| 206 | Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, |
| 207 | Value passthru); |
| 208 | |
| 209 | } // namespace vector |
| 210 | } // namespace mlir |
| 211 | |
| 212 | #define GET_OP_CLASSES |
| 213 | #include "mlir/Dialect/Vector/IR/VectorDialect.h.inc" |
| 214 | #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" |
| 215 | |
| 216 | #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H |
| 217 | |