1 | //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logic for testing Linalg transformations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Passes.h" |
20 | #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
21 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
24 | #include "mlir/Pass/PassManager.h" |
25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
26 | |
27 | #include "llvm/ADT/SmallVector.h" |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::linalg; |
31 | |
32 | namespace { |
33 | struct TestLinalgTransforms |
34 | : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { |
35 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) |
36 | |
37 | TestLinalgTransforms() = default; |
38 | TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} |
39 | |
40 | void getDependentDialects(DialectRegistry ®istry) const override { |
41 | // clang-format off |
42 | registry.insert<affine::AffineDialect, |
43 | bufferization::BufferizationDialect, |
44 | memref::MemRefDialect, |
45 | scf::SCFDialect, |
46 | linalg::LinalgDialect, |
47 | vector::VectorDialect, |
48 | gpu::GPUDialect>(); |
49 | // clang-format on |
50 | } |
51 | StringRef getArgument() const final { |
52 | return "test-linalg-transform-patterns" ; |
53 | } |
54 | StringRef getDescription() const final { |
55 | return "Test Linalg transformation patterns by applying them greedily." ; |
56 | } |
57 | |
58 | void runOnOperation() override; |
59 | |
60 | Option<bool> testPatterns{*this, "test-patterns" , |
61 | llvm::cl::desc("Test a mixed set of patterns" ), |
62 | llvm::cl::init(Val: false)}; |
63 | Option<bool> testVectorTransferForwardingPatterns{ |
64 | *this, "test-vector-transfer-forwarding-patterns" , |
65 | llvm::cl::desc( |
66 | "Test a fused pass that forwards memref.copy to vector.transfer" ), |
67 | llvm::cl::init(Val: false)}; |
68 | Option<bool> testGenericToVectorPattern{ |
69 | *this, "test-linalg-to-vector-patterns" , |
70 | llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " |
71 | "in vector.contract form" ), |
72 | llvm::cl::init(Val: false)}; |
73 | Option<bool> testGeneralizePadTensor{ |
74 | *this, "test-generalize-pad-tensor" , |
75 | llvm::cl::desc("Test transform pad tensor by copying with generic ops" ), |
76 | llvm::cl::init(Val: false)}; |
77 | Option<bool> testGeneralizeTensorPackOp{ |
78 | *this, "test-generalize-tensor-pack" , |
79 | llvm::cl::desc("Test transform that generalizes pack ops into a sequence " |
80 | "of tensor and Linalg ops" ), |
81 | llvm::cl::init(Val: false)}; |
82 | Option<bool> testGeneralizeTensorUnPackOp{ |
83 | *this, "test-generalize-tensor-unpack" , |
84 | llvm::cl::desc( |
85 | "Test transform that generalizes unpack ops into a sequence " |
86 | "of tensor and Linalg ops" ), |
87 | llvm::cl::init(Val: false)}; |
88 | Option<bool> testSwapSubTensorPadTensor{ |
89 | *this, "test-swap-subtensor-padtensor" , |
90 | llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " |
91 | "tensor.pad(subtensor)" ), |
92 | llvm::cl::init(Val: false)}; |
93 | ListOption<int64_t> peeledLoops{ |
94 | *this, "peeled-loops" , |
95 | llvm::cl::desc("Loops to be peeled when test-tile-pattern" )}; |
96 | ListOption<int64_t> tileSizes{ |
97 | *this, "tile-sizes" , |
98 | llvm::cl::desc("Linalg tile sizes for test-tile-pattern" )}; |
99 | Option<bool> skipPartial{ |
100 | *this, "skip-partial" , |
101 | llvm::cl::desc("Skip loops inside partial iterations during peeling" ), |
102 | llvm::cl::init(Val: false)}; |
103 | Option<std::string> loopType{ |
104 | *this, "loop-type" , |
105 | llvm::cl::desc("Specify the type of loops to generate: for, parallel or " |
106 | "tiled_loop" ), |
107 | llvm::cl::init(Val: "for" )}; |
108 | Option<bool> { |
109 | *this, "test-bubble-up-extract-slice-op-pattern" , |
110 | llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " |
111 | "extract_slice + linalgOp" ), |
112 | llvm::cl::init(Val: false)}; |
113 | Option<bool> { |
114 | *this, "test-swap-extract-slice-with-fill-pattern" , |
115 | llvm::cl::desc( |
116 | "Test patterns to swap tensor.extract_slice(linalg.fill())" ), |
117 | llvm::cl::init(Val: false)}; |
118 | Option<bool> testEraseUnusedOperandsAndResults{ |
119 | *this, "test-erase-unused-operands-and-results" , |
120 | llvm::cl::desc("Test patterns to erase unused operands and results" ), |
121 | llvm::cl::init(Val: false)}; |
122 | Option<bool> testEraseUnnecessaryInputs{ |
123 | *this, "test-erase-unnecessary-inputs" , |
124 | llvm::cl::desc("Test patterns to erase unnecessary inputs" ), |
125 | llvm::cl::init(Val: false)}; |
126 | }; |
127 | } // namespace |
128 | |
129 | static void applyPatterns(func::FuncOp funcOp) { |
130 | MLIRContext *ctx = funcOp.getContext(); |
131 | RewritePatternSet patterns(ctx); |
132 | |
133 | //===--------------------------------------------------------------------===// |
134 | // Linalg distribution patterns. |
135 | //===--------------------------------------------------------------------===// |
136 | LinalgLoopDistributionOptions distributionOptions; |
137 | |
138 | //===--------------------------------------------------------------------===// |
139 | // Linalg to vector contraction patterns. |
140 | //===--------------------------------------------------------------------===// |
141 | patterns.add<CopyVectorizationPattern>(arg&: ctx); |
142 | |
143 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
144 | } |
145 | |
146 | static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { |
147 | RewritePatternSet forwardPattern(funcOp.getContext()); |
148 | forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); |
149 | forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); |
150 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); |
151 | } |
152 | |
153 | static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { |
154 | RewritePatternSet patterns(funcOp.getContext()); |
155 | auto *ctx = funcOp.getContext(); |
156 | patterns.add<CopyVectorizationPattern>(ctx); |
157 | populatePadOpVectorizationPatterns(patterns); |
158 | populateConvolutionVectorizationPatterns(patterns); |
159 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
160 | } |
161 | |
162 | static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) { |
163 | RewritePatternSet patterns(funcOp.getContext()); |
164 | patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); |
165 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
166 | } |
167 | |
168 | static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) { |
169 | RewritePatternSet patterns(funcOp.getContext()); |
170 | patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext()); |
171 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
172 | } |
173 | |
174 | static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) { |
175 | RewritePatternSet patterns(funcOp.getContext()); |
176 | patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext()); |
177 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
178 | } |
179 | |
180 | static void (func::FuncOp funcOp) { |
181 | RewritePatternSet patterns(funcOp.getContext()); |
182 | patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); |
183 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
184 | } |
185 | |
186 | static void (func::FuncOp funcOp) { |
187 | RewritePatternSet patterns(funcOp.getContext()); |
188 | populateBubbleUpExtractSliceOpPatterns(patterns); |
189 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
190 | } |
191 | |
192 | static void (func::FuncOp funcOp) { |
193 | RewritePatternSet patterns(funcOp.getContext()); |
194 | populateSwapExtractSliceWithFillPatterns(patterns); |
195 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
196 | } |
197 | |
198 | static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) { |
199 | RewritePatternSet patterns(funcOp.getContext()); |
200 | populateEraseUnusedOperandsAndResultsPatterns(patterns); |
201 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
202 | } |
203 | |
204 | static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { |
205 | RewritePatternSet patterns(funcOp.getContext()); |
206 | populateEraseUnnecessaryInputsPatterns(patterns); |
207 | (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
208 | } |
209 | |
210 | /// Apply transformations specified as patterns. |
211 | void TestLinalgTransforms::runOnOperation() { |
212 | if (testPatterns) |
213 | return applyPatterns(getOperation()); |
214 | if (testVectorTransferForwardingPatterns) |
215 | return applyVectorTransferForwardingPatterns(getOperation()); |
216 | if (testGenericToVectorPattern) |
217 | return applyLinalgToVectorPatterns(getOperation()); |
218 | if (testGeneralizePadTensor) |
219 | return applyGeneralizePadTensorPatterns(getOperation()); |
220 | if (testGeneralizeTensorPackOp) |
221 | return applyGeneralizeTensorPackPatterns(getOperation()); |
222 | if (testGeneralizeTensorUnPackOp) |
223 | return applyGeneralizeTensorUnPackPatterns(getOperation()); |
224 | if (testSwapSubTensorPadTensor) |
225 | return applyExtractSliceOfPadTensorSwapPattern(getOperation()); |
226 | if (testBubbleUpExtractSliceOpPattern) |
227 | return applyBubbleUpExtractSliceOpPattern(getOperation()); |
228 | if (testSwapExtractSliceWithFill) |
229 | return applySwapExtractSliceWithFillPattern(getOperation()); |
230 | if (testEraseUnusedOperandsAndResults) |
231 | return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); |
232 | if (testEraseUnnecessaryInputs) |
233 | return applyEraseUnnecessaryInputs(getOperation()); |
234 | } |
235 | |
236 | namespace mlir { |
237 | namespace test { |
238 | void registerTestLinalgTransforms() { |
239 | PassRegistration<TestLinalgTransforms>(); |
240 | } |
241 | } // namespace test |
242 | } // namespace mlir |
243 | |