1//===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10
11#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
13#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/IR/TransformOps.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22#include "mlir/Dialect/X86Vector/Transforms.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::transform;
28
29//===----------------------------------------------------------------------===//
30// Apply...ConversionPatternsOp
31//===----------------------------------------------------------------------===//
32
33void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
34 TypeConverter &typeConverter, RewritePatternSet &patterns) {
35 populateVectorToLLVMConversionPatterns(
36 static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37 getReassociateFpReductions(), getForce_32bitVectorIndices(),
38 getUseVectorAlignment());
39}
40
41LogicalResult
42transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
43 transform::TypeConverterBuilderOpInterface builder) {
44 if (builder.getTypeConverterType() != "LLVMTypeConverter")
45 return emitOpError("expected LLVMTypeConverter");
46 return success();
47}
48
49//===----------------------------------------------------------------------===//
50// Apply...PatternsOp
51//===----------------------------------------------------------------------===//
52
53void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
54 RewritePatternSet &patterns) {
55 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
56}
57
58void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
59 RewritePatternSet &patterns) {
60 vector::populateFoldArithExtensionPatterns(patterns);
61}
62
63void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
64 RewritePatternSet &patterns) {
65 vector::populateElementwiseToVectorOpsPatterns(patterns);
66}
67
68void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
69 RewritePatternSet &patterns) {
70 vector::populateVectorReductionToContractPatterns(patterns);
71}
72
73void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
74 RewritePatternSet &patterns) {
75 vector::populateVectorMaskOpLoweringPatterns(patterns);
76}
77
78void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
79 RewritePatternSet &patterns) {
80 vector::populateVectorTransferDropUnitDimsPatterns(patterns);
81}
82
83void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
84 RewritePatternSet &patterns) {
85 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
86}
87
88void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
89 RewritePatternSet &patterns) {
90 vector::populateDropUnitDimWithShapeCastPatterns(patterns);
91}
92
93void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
94 RewritePatternSet &patterns) {
95 vector::populateVectorBitCastLoweringPatterns(patterns);
96}
97
98void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
99 RewritePatternSet &patterns) {
100 populateVectorBroadcastLoweringPatterns(patterns);
101}
102
103void transform::ApplyLowerContractionPatternsOp::populatePatterns(
104 RewritePatternSet &patterns) {
105 populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
106 /*benefit=*/1,
107 /*disableOuterProductLowering=*/true);
108}
109
110void transform::ApplyLowerMasksPatternsOp::populatePatterns(
111 RewritePatternSet &patterns) {
112 populateVectorMaskOpLoweringPatterns(patterns);
113}
114
115void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
116 RewritePatternSet &patterns) {
117 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
118}
119
120void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
121 RewritePatternSet &patterns) {
122 populateVectorMaskMaterializationPatterns(patterns,
123 /*force32BitVectorIndices=*/false);
124}
125
126void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
127 RewritePatternSet &patterns) {
128 vector::VectorTransformsOptions vectorTransformOptions;
129 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
130 vector::populateVectorMultiReductionLoweringPatterns(
131 patterns, vectorTransformOptions.vectorMultiReductionLowering);
132}
133
134void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
135 RewritePatternSet &patterns) {
136 populateVectorOuterProductLoweringPatterns(patterns);
137}
138
139void transform::ApplyLowerGatherPatternsOp::populatePatterns(
140 RewritePatternSet &patterns) {
141 vector::populateVectorGatherLoweringPatterns(patterns);
142}
143
144void transform::ApplyLowerScanPatternsOp::populatePatterns(
145 RewritePatternSet &patterns) {
146 vector::populateVectorScanLoweringPatterns(patterns);
147}
148
149void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
150 RewritePatternSet &patterns) {
151 vector::populateVectorShapeCastLoweringPatterns(patterns);
152}
153
154void transform::ApplyLowerTransferPatternsOp::populatePatterns(
155 RewritePatternSet &patterns) {
156 vector::populateVectorTransferLoweringPatterns(patterns,
157 getMaxTransferRank());
158}
159
160void transform::ApplyLowerTransposePatternsOp::populatePatterns(
161 RewritePatternSet &patterns) {
162 vector::populateVectorTransposeLoweringPatterns(patterns,
163 getLoweringStrategy());
164 if (getAvx2LoweringStrategy()) {
165 auto avx2LoweringOptions =
166 x86vector::avx2::LoweringOptions().setTransposeOptions(
167 x86vector::avx2::TransposeLoweringOptions()
168 .lower4x8xf32(true)
169 .lower8x8xf32(true));
170 x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
171 patterns, avx2LoweringOptions, /*benefit=*/10);
172 }
173}
174
175void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
176 RewritePatternSet &patterns) {
177 vector::populateVectorInterleaveLoweringPatterns(patterns);
178}
179
180void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
181 RewritePatternSet &patterns) {
182 vector::populateVectorInterleaveToShufflePatterns(patterns);
183}
184
185void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
186 RewritePatternSet &patterns) {
187 populateVectorNarrowTypeRewritePatterns(patterns);
188 populateVectorTransposeNarrowTypeRewritePatterns(patterns);
189}
190
191void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
192 RewritePatternSet &patterns) {
193 vector::VectorTransformsOptions vectorTransformOptions;
194 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
195 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
196}
197
198void transform::ApplyTransferToScfPatternsOp::populatePatterns(
199 RewritePatternSet &patterns) {
200 VectorTransferToSCFOptions vectorTransferToSCFOptions =
201 VectorTransferToSCFOptions()
202 .enableFullUnroll(getFullUnroll())
203 .setTargetRank(getMaxTransferRank());
204 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
205}
206
207void transform::ApplySinkVectorPatternsOp::populatePatterns(
208 RewritePatternSet &patterns) {
209 vector::populateSinkVectorOpsPatterns(patterns);
210}
211
212void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
213 RewritePatternSet &patterns) {
214 vector::populateSinkVectorMemOpsPatterns(patterns);
215}
216
217//===----------------------------------------------------------------------===//
218// Transform op registration
219//===----------------------------------------------------------------------===//
220
221namespace {
222/// Registers new ops and declares PDL as dependent dialect since the additional
223/// ops are using PDL types for operands and results.
224class VectorTransformDialectExtension
225 : public transform::TransformDialectExtension<
226 VectorTransformDialectExtension> {
227public:
228 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
229
230 VectorTransformDialectExtension() {
231 declareGeneratedDialect<vector::VectorDialect>();
232 declareGeneratedDialect<LLVM::LLVMDialect>();
233 registerTransformOps<
234#define GET_OP_LIST
235#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
236 >();
237 }
238};
239} // namespace
240
241#define GET_OP_CLASSES
242#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
243
244void mlir::vector::registerTransformDialectExtension(
245 DialectRegistry &registry) {
246 registry.addExtensions<VectorTransformDialectExtension>();
247}
248

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp