1//===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10
11#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
13#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/IR/TransformOps.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22#include "mlir/Dialect/X86Vector/Transforms.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::transform;
28
29//===----------------------------------------------------------------------===//
30// Apply...ConversionPatternsOp
31//===----------------------------------------------------------------------===//
32
33void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
34 TypeConverter &typeConverter, RewritePatternSet &patterns) {
35 populateVectorToLLVMConversionPatterns(
36 static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37 getReassociateFpReductions(), getForce_32bitVectorIndices());
38}
39
40LogicalResult
41transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
42 transform::TypeConverterBuilderOpInterface builder) {
43 if (builder.getTypeConverterType() != "LLVMTypeConverter")
44 return emitOpError("expected LLVMTypeConverter");
45 return success();
46}
47
48//===----------------------------------------------------------------------===//
49// Apply...PatternsOp
50//===----------------------------------------------------------------------===//
51
52void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
53 RewritePatternSet &patterns) {
54 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
55}
56
57void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
58 RewritePatternSet &patterns) {
59 vector::populateFoldArithExtensionPatterns(patterns);
60}
61
62void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
63 RewritePatternSet &patterns) {
64 vector::populateVectorReductionToContractPatterns(patterns);
65}
66
67void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
68 RewritePatternSet &patterns) {
69 vector::populateVectorMaskOpLoweringPatterns(patterns);
70}
71
72void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
73 RewritePatternSet &patterns) {
74 vector::populateVectorTransferDropUnitDimsPatterns(patterns);
75}
76
77void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
78 RewritePatternSet &patterns) {
79 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
80}
81
82void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
83 RewritePatternSet &patterns) {
84 populateVectorBroadcastLoweringPatterns(patterns);
85}
86
87void transform::ApplyLowerContractionPatternsOp::populatePatterns(
88 RewritePatternSet &patterns) {
89 vector::VectorTransformsOptions vectorTransformOptions;
90 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
91 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
92 /*benefit=*/1,
93 /*disableOuterProductLowering=*/true);
94}
95
96void transform::ApplyLowerMasksPatternsOp::populatePatterns(
97 RewritePatternSet &patterns) {
98 populateVectorMaskOpLoweringPatterns(patterns);
99}
100
101void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
102 RewritePatternSet &patterns) {
103 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
104}
105
106void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
107 RewritePatternSet &patterns) {
108 populateVectorMaskMaterializationPatterns(patterns,
109 /*force32BitVectorIndices=*/false);
110}
111
112void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
113 RewritePatternSet &patterns) {
114 vector::VectorTransformsOptions vectorTransformOptions;
115 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
116 vector::populateVectorMultiReductionLoweringPatterns(
117 patterns, vectorTransformOptions.vectorMultiReductionLowering);
118}
119
120void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
121 RewritePatternSet &patterns) {
122 populateVectorOuterProductLoweringPatterns(patterns);
123}
124
125void transform::ApplyLowerGatherPatternsOp::populatePatterns(
126 RewritePatternSet &patterns) {
127 vector::populateVectorGatherLoweringPatterns(patterns);
128}
129
130void transform::ApplyLowerScanPatternsOp::populatePatterns(
131 RewritePatternSet &patterns) {
132 vector::populateVectorScanLoweringPatterns(patterns);
133}
134
135void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
136 RewritePatternSet &patterns) {
137 vector::populateVectorShapeCastLoweringPatterns(patterns);
138}
139
140void transform::ApplyLowerTransferPatternsOp::populatePatterns(
141 RewritePatternSet &patterns) {
142 vector::populateVectorTransferLoweringPatterns(patterns,
143 getMaxTransferRank());
144}
145
146void transform::ApplyLowerTransposePatternsOp::populatePatterns(
147 RewritePatternSet &patterns) {
148 vector::populateVectorTransposeLoweringPatterns(
149 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
150 getLoweringStrategy()));
151 if (getAvx2LoweringStrategy()) {
152 auto avx2LoweringOptions =
153 x86vector::avx2::LoweringOptions().setTransposeOptions(
154 x86vector::avx2::TransposeLoweringOptions()
155 .lower4x8xf32(true)
156 .lower8x8xf32(true));
157 x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
158 patterns, avx2LoweringOptions, /*benefit=*/10);
159 }
160}
161
162void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
163 RewritePatternSet &patterns) {
164 vector::populateVectorInterleaveLoweringPatterns(patterns);
165}
166
167void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
168 RewritePatternSet &patterns) {
169 populateVectorNarrowTypeRewritePatterns(patterns);
170 populateVectorTransposeNarrowTypeRewritePatterns(patterns);
171}
172
173void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
174 RewritePatternSet &patterns) {
175 vector::VectorTransformsOptions vectorTransformOptions;
176 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
177 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
178}
179
180void transform::ApplyTransferToScfPatternsOp::populatePatterns(
181 RewritePatternSet &patterns) {
182 VectorTransferToSCFOptions vectorTransferToSCFOptions =
183 VectorTransferToSCFOptions()
184 .enableFullUnroll(getFullUnroll())
185 .setTargetRank(getMaxTransferRank());
186 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
187}
188
189//===----------------------------------------------------------------------===//
190// Transform op registration
191//===----------------------------------------------------------------------===//
192
193namespace {
194/// Registers new ops and declares PDL as dependent dialect since the additional
195/// ops are using PDL types for operands and results.
196class VectorTransformDialectExtension
197 : public transform::TransformDialectExtension<
198 VectorTransformDialectExtension> {
199public:
200 VectorTransformDialectExtension() {
201 declareGeneratedDialect<vector::VectorDialect>();
202 declareGeneratedDialect<LLVM::LLVMDialect>();
203 registerTransformOps<
204#define GET_OP_LIST
205#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
206 >();
207 }
208};
209} // namespace
210
211#define GET_OP_CLASSES
212#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
213
214void mlir::vector::registerTransformDialectExtension(
215 DialectRegistry &registry) {
216 registry.addExtensions<VectorTransformDialectExtension>();
217}
218

// Source: mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp