1 | //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logic for testing Linalg transformations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Passes.h" |
20 | #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
21 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
24 | #include "mlir/Pass/PassManager.h" |
25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
26 | |
27 | #include "llvm/ADT/SmallVector.h" |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::linalg; |
31 | |
32 | namespace { |
33 | struct TestLinalgTransforms |
34 | : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { |
35 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) |
36 | |
37 | TestLinalgTransforms() = default; |
38 | TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} |
39 | |
40 | void getDependentDialects(DialectRegistry ®istry) const override { |
41 | // clang-format off |
42 | registry.insert<affine::AffineDialect, |
43 | bufferization::BufferizationDialect, |
44 | memref::MemRefDialect, |
45 | scf::SCFDialect, |
46 | linalg::LinalgDialect, |
47 | vector::VectorDialect, |
48 | gpu::GPUDialect>(); |
49 | // clang-format on |
50 | } |
51 | StringRef getArgument() const final { |
52 | return "test-linalg-transform-patterns" ; |
53 | } |
54 | StringRef getDescription() const final { |
55 | return "Test Linalg transformation patterns by applying them greedily." ; |
56 | } |
57 | |
58 | void runOnOperation() override; |
59 | |
60 | Option<bool> testPatterns{*this, "test-patterns" , |
61 | llvm::cl::desc("Test a mixed set of patterns" ), |
62 | llvm::cl::init(Val: false)}; |
63 | Option<bool> testVectorTransferForwardingPatterns{ |
64 | *this, "test-vector-transfer-forwarding-patterns" , |
65 | llvm::cl::desc( |
66 | "Test a fused pass that forwards memref.copy to vector.transfer" ), |
67 | llvm::cl::init(Val: false)}; |
68 | Option<bool> testGenericToVectorPattern{ |
69 | *this, "test-linalg-to-vector-patterns" , |
70 | llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " |
71 | "in vector.contract form" ), |
72 | llvm::cl::init(Val: false)}; |
73 | Option<bool> testDecomposePadTensor{ |
74 | *this, "test-decompose-pad-tensor" , |
75 | llvm::cl::desc("Test transform pad tensor by copying with generic ops" ), |
76 | llvm::cl::init(Val: false)}; |
77 | // TODO: This is not used - delete. |
78 | Option<bool> testDecomposeTensorPackOp{ |
79 | *this, "test-decompose-linalg-pack" , |
80 | llvm::cl::desc("Test transform that generalizes pack ops into a sequence " |
81 | "of tensor and Linalg ops" ), |
82 | llvm::cl::init(Val: false)}; |
83 | Option<bool> testDecomposeTensorUnPackOp{ |
84 | *this, "test-decompose-tensor-unpack" , |
85 | llvm::cl::desc( |
86 | "Test transform that generalizes unpack ops into a sequence " |
87 | "of tensor and Linalg ops" ), |
88 | llvm::cl::init(Val: false)}; |
89 | Option<bool> testSwapSubTensorPadTensor{ |
90 | *this, "test-swap-subtensor-padtensor" , |
91 | llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " |
92 | "tensor.pad(subtensor)" ), |
93 | llvm::cl::init(Val: false)}; |
94 | ListOption<int64_t> peeledLoops{ |
95 | *this, "peeled-loops" , |
96 | llvm::cl::desc("Loops to be peeled when test-tile-pattern" )}; |
97 | ListOption<int64_t> tileSizes{ |
98 | *this, "tile-sizes" , |
99 | llvm::cl::desc("Linalg tile sizes for test-tile-pattern" )}; |
100 | Option<bool> skipPartial{ |
101 | *this, "skip-partial" , |
102 | llvm::cl::desc("Skip loops inside partial iterations during peeling" ), |
103 | llvm::cl::init(Val: false)}; |
104 | Option<std::string> loopType{ |
105 | *this, "loop-type" , |
106 | llvm::cl::desc("Specify the type of loops to generate: for, parallel or " |
107 | "tiled_loop" ), |
108 | llvm::cl::init(Val: "for" )}; |
109 | Option<bool> { |
110 | *this, "test-bubble-up-extract-slice-op-pattern" , |
111 | llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " |
112 | "extract_slice + linalgOp" ), |
113 | llvm::cl::init(Val: false)}; |
114 | Option<bool> { |
115 | *this, "test-swap-extract-slice-with-fill-pattern" , |
116 | llvm::cl::desc( |
117 | "Test patterns to swap tensor.extract_slice(linalg.fill())" ), |
118 | llvm::cl::init(Val: false)}; |
119 | Option<bool> testEraseUnusedOperandsAndResults{ |
120 | *this, "test-erase-unused-operands-and-results" , |
121 | llvm::cl::desc("Test patterns to erase unused operands and results" ), |
122 | llvm::cl::init(Val: false)}; |
123 | Option<bool> testEraseUnnecessaryInputs{ |
124 | *this, "test-erase-unnecessary-inputs" , |
125 | llvm::cl::desc("Test patterns to erase unnecessary inputs" ), |
126 | llvm::cl::init(Val: false)}; |
127 | Option<bool> testWinogradConv2D{ |
128 | *this, "test-winograd-conv2d" , |
129 | llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm" ), |
130 | llvm::cl::init(Val: false)}; |
131 | Option<bool> testDecomposeWinogradOps{ |
132 | *this, "test-decompose-winograd-ops" , |
133 | llvm::cl::desc("Test decompose Winograd ops" ), llvm::cl::init(Val: false)}; |
134 | Option<bool> testFoldIntoPackAndUnpack{ |
135 | *this, "test-fold-into-pack-and-unpack" , |
136 | llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack" ), |
137 | llvm::cl::init(Val: false)}; |
138 | Option<bool> testSimplifyPackUnpackPatterns{ |
139 | *this, "test-simplify-pack-unpack-patterns" , |
140 | llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack" ), |
141 | llvm::cl::init(Val: false)}; |
142 | }; |
143 | } // namespace |
144 | |
145 | static void applyPatterns(func::FuncOp funcOp) { |
146 | MLIRContext *ctx = funcOp.getContext(); |
147 | RewritePatternSet patterns(ctx); |
148 | |
149 | //===--------------------------------------------------------------------===// |
150 | // Linalg distribution patterns. |
151 | //===--------------------------------------------------------------------===// |
152 | LinalgLoopDistributionOptions distributionOptions; |
153 | |
154 | //===--------------------------------------------------------------------===// |
155 | // Linalg to vector contraction patterns. |
156 | //===--------------------------------------------------------------------===// |
157 | patterns.add<CopyVectorizationPattern>(arg&: ctx); |
158 | |
159 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
160 | } |
161 | |
162 | static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { |
163 | RewritePatternSet forwardPattern(funcOp.getContext()); |
164 | forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); |
165 | forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); |
166 | (void)applyPatternsGreedily(funcOp, std::move(forwardPattern)); |
167 | } |
168 | |
169 | static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { |
170 | RewritePatternSet patterns(funcOp.getContext()); |
171 | auto *ctx = funcOp.getContext(); |
172 | patterns.add<CopyVectorizationPattern>(ctx); |
173 | populatePadOpVectorizationPatterns(patterns); |
174 | populateConvolutionVectorizationPatterns(patterns); |
175 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
176 | } |
177 | |
178 | static void applyDecomposePadPatterns(func::FuncOp funcOp) { |
179 | RewritePatternSet patterns(funcOp.getContext()); |
180 | patterns.add<DecomposePadOpPattern>(funcOp.getContext()); |
181 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
182 | } |
183 | |
184 | static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) { |
185 | RewritePatternSet patterns(funcOp.getContext()); |
186 | patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext()); |
187 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
188 | } |
189 | |
190 | static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) { |
191 | RewritePatternSet patterns(funcOp.getContext()); |
192 | patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext()); |
193 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
194 | } |
195 | |
196 | static void (func::FuncOp funcOp) { |
197 | RewritePatternSet patterns(funcOp.getContext()); |
198 | patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); |
199 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
200 | } |
201 | |
202 | static void (func::FuncOp funcOp) { |
203 | RewritePatternSet patterns(funcOp.getContext()); |
204 | populateBubbleUpExtractSliceOpPatterns(patterns); |
205 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
206 | } |
207 | |
208 | static void (func::FuncOp funcOp) { |
209 | RewritePatternSet patterns(funcOp.getContext()); |
210 | populateSwapExtractSliceWithFillPatterns(patterns); |
211 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
212 | } |
213 | |
214 | static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) { |
215 | RewritePatternSet patterns(funcOp.getContext()); |
216 | populateEraseUnusedOperandsAndResultsPatterns(patterns); |
217 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
218 | } |
219 | |
220 | static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { |
221 | RewritePatternSet patterns(funcOp.getContext()); |
222 | populateEraseUnnecessaryInputsPatterns(patterns); |
223 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
224 | } |
225 | |
226 | static void applyWinogradConv2D(func::FuncOp funcOp) { |
227 | RewritePatternSet patterns(funcOp.getContext()); |
228 | populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); |
229 | populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); |
230 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
231 | } |
232 | |
233 | static void applyDecomposeWinogradOps(func::FuncOp funcOp) { |
234 | RewritePatternSet patterns(funcOp.getContext()); |
235 | populateDecomposeWinogradOpsPatterns(patterns); |
236 | (void)applyPatternsGreedily(funcOp, std::move(patterns)); |
237 | } |
238 | |
239 | static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { |
240 | RewritePatternSet patterns(rootOp->getContext()); |
241 | linalg::populateFoldIntoPackAndUnpackPatterns(patterns); |
242 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
243 | } |
244 | |
245 | static void applySimplifyPackUnpackPatterns(Operation *rootOp) { |
246 | RewritePatternSet patterns(rootOp->getContext()); |
247 | linalg::populateSimplifyPackAndUnpackPatterns(patterns); |
248 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
249 | } |
250 | |
251 | /// Apply transformations specified as patterns. |
252 | void TestLinalgTransforms::runOnOperation() { |
253 | if (testPatterns) |
254 | return applyPatterns(getOperation()); |
255 | if (testVectorTransferForwardingPatterns) |
256 | return applyVectorTransferForwardingPatterns(getOperation()); |
257 | if (testGenericToVectorPattern) |
258 | return applyLinalgToVectorPatterns(getOperation()); |
259 | if (testDecomposePadTensor) |
260 | return applyDecomposePadPatterns(getOperation()); |
261 | if (testDecomposeTensorPackOp) |
262 | return applyDecomposeTensorPackPatterns(getOperation()); |
263 | if (testDecomposeTensorUnPackOp) |
264 | return applyDecomposeTensorUnPackPatterns(getOperation()); |
265 | if (testSwapSubTensorPadTensor) |
266 | return applyExtractSliceOfPadTensorSwapPattern(getOperation()); |
267 | if (testBubbleUpExtractSliceOpPattern) |
268 | return applyBubbleUpExtractSliceOpPattern(getOperation()); |
269 | if (testSwapExtractSliceWithFill) |
270 | return applySwapExtractSliceWithFillPattern(getOperation()); |
271 | if (testEraseUnusedOperandsAndResults) |
272 | return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); |
273 | if (testEraseUnnecessaryInputs) |
274 | return applyEraseUnnecessaryInputs(getOperation()); |
275 | if (testWinogradConv2D) |
276 | return applyWinogradConv2D(getOperation()); |
277 | if (testDecomposeWinogradOps) |
278 | return applyDecomposeWinogradOps(getOperation()); |
279 | Operation *rootOp = getOperation(); |
280 | if (testFoldIntoPackAndUnpack) |
281 | applyFoldIntoPackAndUnpackPatterns(rootOp); |
282 | if (testSimplifyPackUnpackPatterns) |
283 | applySimplifyPackUnpackPatterns(rootOp); |
284 | } |
285 | |
286 | namespace mlir { |
287 | namespace test { |
288 | void registerTestLinalgTransforms() { |
289 | PassRegistration<TestLinalgTransforms>(); |
290 | } |
291 | } // namespace test |
292 | } // namespace mlir |
293 | |