1 | //===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H |
10 | #define MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H |
11 | |
12 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
13 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
14 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
15 | #include "mlir/IR/OpImplementation.h" |
16 | |
// Forward declarations placed ahead of the generated op-class include below.
// NOTE(review): presumably referenced by the generated op declarations;
// LowerVectorsOptions is fully defined later in this header.
namespace mlir::vector {
struct LowerVectorsOptions;
class VectorOp;
} // namespace mlir::vector
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // Vector Transform Operations |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | #define GET_OP_CLASSES |
29 | #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h.inc" |
30 | |
31 | namespace mlir { |
32 | class DialectRegistry; |
33 | |
34 | namespace vector { |
35 | void registerTransformDialectExtension(DialectRegistry ®istry); |
36 | |
37 | /// Helper structure used to hold the different options of LowerVectorsOp. |
38 | struct LowerVectorsOptions : public VectorTransformsOptions { |
39 | // Have the default values match the LowerVectorsOp values in the td file. |
40 | LowerVectorsOptions() : VectorTransformsOptions() { |
41 | setVectorTransformsOptions(VectorContractLowering::OuterProduct); |
42 | setVectorMultiReductionLowering( |
43 | VectorMultiReductionLowering::InnerParallel); |
44 | setVectorTransposeLowering(VectorTransposeLowering::EltWise); |
45 | setVectorTransferSplit(VectorTransferSplit::LinalgCopy); |
46 | } |
47 | |
48 | /// Duplicate the base API of VectorTransformsOptions but return the |
49 | /// LowerVectorsOptions type. This allows to really set up the different |
50 | /// options in any order via chained setXXX calls. @{ |
51 | LowerVectorsOptions &setVectorTransformsOptions(VectorContractLowering opt) { |
52 | VectorTransformsOptions::setVectorTransformsOptions(opt); |
53 | return *this; |
54 | } |
55 | |
56 | LowerVectorsOptions & |
57 | setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { |
58 | VectorTransformsOptions::setVectorMultiReductionLowering(opt); |
59 | return *this; |
60 | } |
61 | LowerVectorsOptions &setVectorTransposeLowering(VectorTransposeLowering opt) { |
62 | VectorTransformsOptions::setVectorTransposeLowering(opt); |
63 | return *this; |
64 | } |
65 | LowerVectorsOptions &setVectorTransferSplit(VectorTransferSplit opt) { |
66 | VectorTransformsOptions::setVectorTransferSplit(opt); |
67 | return *this; |
68 | } |
69 | /// @} |
70 | |
71 | bool transposeAVX2Lowering = false; |
72 | LowerVectorsOptions &setTransposeAVX2Lowering(bool opt) { |
73 | transposeAVX2Lowering = opt; |
74 | return *this; |
75 | } |
76 | |
77 | bool unrollVectorTransfers = true; |
78 | LowerVectorsOptions &setUnrollVectorTransfers(bool opt) { |
79 | unrollVectorTransfers = opt; |
80 | return *this; |
81 | } |
82 | }; |
83 | } // namespace vector |
84 | } // namespace mlir |
85 | |
86 | #endif // MLIR_DIALECT_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H |
87 | |