//===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10
11#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
13#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
19#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
20#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
21#include "mlir/Dialect/X86Vector/Transforms.h"
22
23using namespace mlir;
24using namespace mlir::vector;
25using namespace mlir::transform;
26
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//
30
31void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
32 TypeConverter &typeConverter, RewritePatternSet &patterns) {
33 populateVectorToLLVMConversionPatterns(
34 converter: static_cast<LLVMTypeConverter &>(typeConverter), patterns,
35 reassociateFPReductions: getReassociateFpReductions(), force32BitVectorIndices: getForce_32bitVectorIndices(),
36 useVectorAlignment: getUseVectorAlignment());
37}
38
39LogicalResult
40transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
41 transform::TypeConverterBuilderOpInterface builder) {
42 if (builder.getTypeConverterType() != "LLVMTypeConverter")
43 return emitOpError(message: "expected LLVMTypeConverter");
44 return success();
45}
46
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
50
51void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
52 RewritePatternSet &patterns) {
53 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
54}
55
56void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
57 RewritePatternSet &patterns) {
58 vector::populateFoldArithExtensionPatterns(patterns);
59}
60
61void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
62 RewritePatternSet &patterns) {
63 vector::populateElementwiseToVectorOpsPatterns(patterns);
64}
65
66void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
67 RewritePatternSet &patterns) {
68 vector::populateVectorReductionToContractPatterns(patterns);
69}
70
71void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
72 RewritePatternSet &patterns) {
73 vector::populateVectorMaskOpLoweringPatterns(patterns);
74}
75
76void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
77 RewritePatternSet &patterns) {
78 vector::populateVectorTransferDropUnitDimsPatterns(patterns);
79}
80
81void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
82 RewritePatternSet &patterns) {
83 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
84}
85
86void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
87 RewritePatternSet &patterns) {
88 vector::populateDropUnitDimWithShapeCastPatterns(patterns);
89}
90
91void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
92 RewritePatternSet &patterns) {
93 vector::populateVectorBitCastLoweringPatterns(patterns);
94}
95
96void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
97 RewritePatternSet &patterns) {
98 populateVectorBroadcastLoweringPatterns(patterns);
99}
100
101void transform::ApplyLowerContractionPatternsOp::populatePatterns(
102 RewritePatternSet &patterns) {
103 populateVectorContractLoweringPatterns(patterns, vectorContractLoweringOption: getLoweringStrategy(),
104 /*benefit=*/1,
105 /*disableOuterProductLowering=*/true);
106}
107
108void transform::ApplyLowerMasksPatternsOp::populatePatterns(
109 RewritePatternSet &patterns) {
110 populateVectorMaskOpLoweringPatterns(patterns);
111}
112
113void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
114 RewritePatternSet &patterns) {
115 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
116}
117
118void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
119 RewritePatternSet &patterns) {
120 populateVectorMaskMaterializationPatterns(patterns,
121 /*force32BitVectorIndices=*/false);
122}
123
124void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
125 RewritePatternSet &patterns) {
126 vector::VectorTransformsOptions vectorTransformOptions;
127 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
128 vector::populateVectorMultiReductionLoweringPatterns(
129 patterns, options: vectorTransformOptions.vectorMultiReductionLowering);
130}
131
132void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
133 RewritePatternSet &patterns) {
134 populateVectorOuterProductLoweringPatterns(patterns);
135}
136
137void transform::ApplyLowerGatherPatternsOp::populatePatterns(
138 RewritePatternSet &patterns) {
139 vector::populateVectorGatherLoweringPatterns(patterns);
140}
141
142void transform::ApplyLowerScanPatternsOp::populatePatterns(
143 RewritePatternSet &patterns) {
144 vector::populateVectorScanLoweringPatterns(patterns);
145}
146
147void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
148 RewritePatternSet &patterns) {
149 vector::populateVectorShapeCastLoweringPatterns(patterns);
150}
151
152void transform::ApplyLowerTransferPatternsOp::populatePatterns(
153 RewritePatternSet &patterns) {
154 vector::populateVectorTransferLoweringPatterns(patterns,
155 maxTransferRank: getMaxTransferRank());
156}
157
158void transform::ApplyLowerTransposePatternsOp::populatePatterns(
159 RewritePatternSet &patterns) {
160 vector::populateVectorTransposeLoweringPatterns(patterns,
161 vectorTransposeLowering: getLoweringStrategy());
162 if (getAvx2LoweringStrategy()) {
163 auto avx2LoweringOptions =
164 x86vector::avx2::LoweringOptions().setTransposeOptions(
165 x86vector::avx2::TransposeLoweringOptions()
166 .lower4x8xf32(lower: true)
167 .lower8x8xf32(lower: true));
168 x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
169 patterns, options: avx2LoweringOptions, /*benefit=*/10);
170 }
171}
172
173void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
174 RewritePatternSet &patterns) {
175 vector::populateVectorInterleaveLoweringPatterns(patterns);
176}
177
178void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
179 RewritePatternSet &patterns) {
180 vector::populateVectorInterleaveToShufflePatterns(patterns);
181}
182
183void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
184 RewritePatternSet &patterns) {
185 populateVectorNarrowTypeRewritePatterns(patterns);
186 populateVectorTransposeNarrowTypeRewritePatterns(patterns);
187}
188
189void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
190 RewritePatternSet &patterns) {
191 vector::VectorTransformsOptions vectorTransformOptions;
192 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
193 populateVectorTransferFullPartialPatterns(patterns, options: vectorTransformOptions);
194}
195
196void transform::ApplyTransferToScfPatternsOp::populatePatterns(
197 RewritePatternSet &patterns) {
198 VectorTransferToSCFOptions vectorTransferToSCFOptions =
199 VectorTransferToSCFOptions()
200 .enableFullUnroll(u: getFullUnroll())
201 .setTargetRank(getMaxTransferRank());
202 populateVectorToSCFConversionPatterns(patterns, options: vectorTransferToSCFOptions);
203}
204
205void transform::ApplySinkVectorPatternsOp::populatePatterns(
206 RewritePatternSet &patterns) {
207 vector::populateSinkVectorOpsPatterns(patterns);
208}
209
210void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
211 RewritePatternSet &patterns) {
212 vector::populateSinkVectorMemOpsPatterns(patterns);
213}
214
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
218
219namespace {
220/// Registers new ops and declares PDL as dependent dialect since the additional
221/// ops are using PDL types for operands and results.
222class VectorTransformDialectExtension
223 : public transform::TransformDialectExtension<
224 VectorTransformDialectExtension> {
225public:
226 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
227
228 VectorTransformDialectExtension() {
229 declareGeneratedDialect<vector::VectorDialect>();
230 declareGeneratedDialect<LLVM::LLVMDialect>();
231 registerTransformOps<
232#define GET_OP_LIST
233#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
234 >();
235 }
236};
237} // namespace
238
239#define GET_OP_CLASSES
240#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
241
242void mlir::vector::registerTransformDialectExtension(
243 DialectRegistry &registry) {
244 registry.addExtensions<VectorTransformDialectExtension>();
245}
246

// Source: mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp