1//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Vector dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
14#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
15
16#include "mlir/Bytecode/BytecodeOpInterface.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
19#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Dialect.h"
24#include "mlir/IR/OpDefinition.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Interfaces/ControlFlowInterfaces.h"
27#include "mlir/Interfaces/DestinationStyleOpInterface.h"
28#include "mlir/Interfaces/InferTypeOpInterface.h"
29#include "mlir/Interfaces/SideEffectInterfaces.h"
30#include "mlir/Interfaces/VectorInterfaces.h"
31#include "mlir/Interfaces/ViewLikeInterface.h"
32#include "llvm/ADT/SetVector.h"
33#include "llvm/ADT/StringExtras.h"
34
35// Pull in all enum type definitions and utility function declarations.
36#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"
37
38#define GET_ATTRDEF_CLASSES
39#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"
40
41namespace mlir {
42class MLIRContext;
43class RewritePatternSet;
44
45namespace arith {
46enum class AtomicRMWKind : uint64_t;
47} // namespace arith
48
49namespace vector {
50class ContractionOp;
51class TransferReadOp;
52class TransferWriteOp;
53class VectorDialect;
54
55namespace detail {
56struct BitmaskEnumStorage;
57} // namespace detail
58
/// Predefined kinds of `vector.constant_mask`: an all-false mask or an
/// all-true mask.
enum class ConstantMaskKind {
  AllFalse = 0,
  AllTrue = 1,
};
61
62/// Default callback to build a region with a 'vector.yield' terminator with no
63/// arguments.
64void buildTerminatedBody(OpBuilder &builder, Location loc);
65
66/// Return whether `srcType` can be broadcast to `dstVectorType` under the
67/// semantics of the `vector.broadcast` op.
68enum class BroadcastableToResult {
69 Success = 0,
70 SourceRankHigher = 1,
71 DimensionMismatch = 2,
72 SourceTypeNotAVector = 3
73};
74
75struct VectorDim {
76 int64_t dim;
77 bool isScalable;
78};
79BroadcastableToResult
80isBroadcastableTo(Type srcType, VectorType dstVectorType,
81 std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
82
83/// Collect a set of vector-to-vector canonicalization patterns.
84void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
85 PatternBenefit benefit = 1);
86
87/// Collect a set of patterns that fold arithmetic extension on floating point
88/// into vector contract for the backends with native support.
89void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
90
91/// Collect a set of patterns that fold elementwise op on vectors to the vector
92/// dialect.
93void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
94
95/// Returns the integer type required for subscripts in the vector dialect.
96IntegerType getVectorSubscriptType(Builder &builder);
97
98/// Returns an integer array attribute containing the given values using
99/// the integer type required for subscripts in the vector dialect.
100ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
101
102/// Returns the value obtained by reducing the vector into a scalar using the
103/// operation kind associated with a binary AtomicRMWKind op.
104Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
105 Location loc, Value vector);
106
107/// Build the default minor identity map suitable for a vector transfer. This
108/// also handles the case memref<... x vector<...>> -> vector<...> in which the
109/// rank of the identity map must take the vector element type into account.
110AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
111 VectorType vectorType);
112
113/// Return true if the transfer_write fully writes the data accessed by the
114/// transfer_read.
115bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
116
/// Return true if the `write` op fully overwrites the data written by the
/// `priorWrite` transfer_write op.
119bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
120
/// Return true if we can prove that the transfer operations access disjoint
/// memory, without requiring the accessed tensor/memref to be the same.
123///
124/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
125/// via ValueBoundsOpInterface.
126bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
127 VectorTransferOpInterface transferB,
128 bool testDynamicValueUsingBounds = false);
129
130/// Return true if we can prove that the transfer operations access disjoint
131/// memory, requiring the operations to access the same tensor/memref.
132///
133/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
134/// via ValueBoundsOpInterface.
135bool isDisjointTransferSet(VectorTransferOpInterface transferA,
136 VectorTransferOpInterface transferB,
137 bool testDynamicValueUsingBounds = false);
138
139/// Returns the result value of reducing two scalar/vector values with the
140/// corresponding arith operation.
141Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
142 Value v1, Value acc,
143 arith::FastMathFlagsAttr fastmath = nullptr,
144 Value mask = nullptr);
145
146/// Returns true if `attr` has "parallel" iterator type semantics.
147inline bool isParallelIterator(Attribute attr) {
148 return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel;
149}
150
151/// Returns true if `attr` has "reduction" iterator type semantics.
152inline bool isReductionIterator(Attribute attr) {
153 return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
154}
155
/// Returns the integer values held by `values`. Each element of `values` is
/// expected to be defined by a constant operation.
158SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
159
/// Returns the integer values held by `foldResults`. Each element of
/// `foldResults` is expected to be a constant.
162SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
163
164/// Convert `foldResults` into Values. Integer attributes are converted to
165/// constant op.
166SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
167 ArrayRef<OpFoldResult> foldResults);
168
169/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
170/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
171/// `std::nullopt`.
172std::optional<int64_t> getConstantVscaleMultiplier(Value value);
173
174//===----------------------------------------------------------------------===//
175// Vector Masking Utilities
176//===----------------------------------------------------------------------===//
177
/// Infers the mask type for a transfer op given its vector type and
/// permutation map. The mask of a transfer op applies to the tensor/buffer
/// part of it and its type should match the vector shape *before* any
/// permutation or broadcasting. For example,
182///
183/// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)>
184///
185/// Has inferred mask type:
186///
187/// maskType = vector<2x1xi1>
188VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
189
190/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
191/// as masked operation.
192void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
193
194/// Creates a vector.mask operation around a maskable operation. Returns the
195/// vector.mask operation if the mask provided is valid. Otherwise, returns the
196/// maskable operation itself.
197Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
198 Value passthru = Value());
199
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used
/// to propagate the pass-thru value for masked-out or speculatively executed
/// lanes. VP intrinsics do not support pass-thru values and every masked-out
/// lane is set to poison. LLVM backends are usually able to match op + select
/// patterns and fold them into native target instructions.
206Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
207 Value passthru);
208
209} // namespace vector
210} // namespace mlir
211
212#define GET_OP_CLASSES
213#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
214#include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
215
216#endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
217

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/include/mlir/Dialect/Vector/IR/VectorOps.h