1//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic for testing Linalg transformations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Passes.h"
20#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22#include "mlir/Dialect/Linalg/Utils/Utils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Pass/PassManager.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27#include "llvm/ADT/SmallVector.h"
28
29using namespace mlir;
30using namespace mlir::linalg;
31
32namespace {
33struct TestLinalgTransforms
34 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
36
37 TestLinalgTransforms() = default;
38 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
39
40 void getDependentDialects(DialectRegistry &registry) const override {
41 // clang-format off
42 registry.insert<affine::AffineDialect,
43 bufferization::BufferizationDialect,
44 memref::MemRefDialect,
45 scf::SCFDialect,
46 linalg::LinalgDialect,
47 vector::VectorDialect,
48 gpu::GPUDialect>();
49 // clang-format on
50 }
51 StringRef getArgument() const final {
52 return "test-linalg-transform-patterns";
53 }
54 StringRef getDescription() const final {
55 return "Test Linalg transformation patterns by applying them greedily.";
56 }
57
58 void runOnOperation() override;
59
60 Option<bool> testPatterns{*this, "test-patterns",
61 llvm::cl::desc("Test a mixed set of patterns"),
62 llvm::cl::init(Val: false)};
63 Option<bool> testVectorTransferForwardingPatterns{
64 *this, "test-vector-transfer-forwarding-patterns",
65 llvm::cl::desc(
66 "Test a fused pass that forwards memref.copy to vector.transfer"),
67 llvm::cl::init(Val: false)};
68 Option<bool> testDecomposePadTensor{
69 *this, "test-decompose-pad-tensor",
70 llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
71 llvm::cl::init(Val: false)};
72 // TODO: This is not used - delete.
73 Option<bool> testDecomposeTensorPackOp{
74 *this, "test-decompose-linalg-pack",
75 llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
76 "of tensor and Linalg ops"),
77 llvm::cl::init(Val: false)};
78 Option<bool> testDecomposeTensorUnPackOp{
79 *this, "test-decompose-tensor-unpack",
80 llvm::cl::desc(
81 "Test transform that generalizes unpack ops into a sequence "
82 "of tensor and Linalg ops"),
83 llvm::cl::init(Val: false)};
84 Option<bool> testSwapSubTensorPadTensor{
85 *this, "test-swap-subtensor-padtensor",
86 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
87 "tensor.pad(subtensor)"),
88 llvm::cl::init(Val: false)};
89 ListOption<int64_t> peeledLoops{
90 *this, "peeled-loops",
91 llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
92 ListOption<int64_t> tileSizes{
93 *this, "tile-sizes",
94 llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
95 Option<bool> skipPartial{
96 *this, "skip-partial",
97 llvm::cl::desc("Skip loops inside partial iterations during peeling"),
98 llvm::cl::init(Val: false)};
99 Option<std::string> loopType{
100 *this, "loop-type",
101 llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
102 "tiled_loop"),
103 llvm::cl::init(Val: "for")};
104 Option<bool> testBubbleUpExtractSliceOpPattern{
105 *this, "test-bubble-up-extract-slice-op-pattern",
106 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
107 "extract_slice + linalgOp"),
108 llvm::cl::init(Val: false)};
109 Option<bool> testSwapExtractSliceWithFill{
110 *this, "test-swap-extract-slice-with-fill-pattern",
111 llvm::cl::desc(
112 "Test patterns to swap tensor.extract_slice(linalg.fill())"),
113 llvm::cl::init(Val: false)};
114 Option<bool> testEraseUnusedOperandsAndResults{
115 *this, "test-erase-unused-operands-and-results",
116 llvm::cl::desc("Test patterns to erase unused operands and results"),
117 llvm::cl::init(Val: false)};
118 Option<bool> testEraseUnnecessaryInputs{
119 *this, "test-erase-unnecessary-inputs",
120 llvm::cl::desc("Test patterns to erase unnecessary inputs"),
121 llvm::cl::init(Val: false)};
122 Option<bool> testWinogradConv2D{
123 *this, "test-winograd-conv2d",
124 llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
125 llvm::cl::init(Val: false)};
126 Option<bool> testDecomposeWinogradOps{
127 *this, "test-decompose-winograd-ops",
128 llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(Val: false)};
129 Option<bool> testFoldIntoPackAndUnpack{
130 *this, "test-fold-into-pack-and-unpack",
131 llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack"),
132 llvm::cl::init(Val: false)};
133 Option<bool> testFoldIntoPackAndUnpackWithControlFn{
134 *this, "test-fold-into-pack-and-unpack-control",
135 llvm::cl::desc(
136 "Test controlling folding ops into linalg.pack and linalg.unpack"),
137 llvm::cl::init(Val: false)};
138 Option<bool> testSimplifyPackUnpackPatterns{
139 *this, "test-simplify-pack-unpack-patterns",
140 llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack"),
141 llvm::cl::init(Val: false)};
142};
143} // namespace
144
145static void applyPatterns(func::FuncOp funcOp) {
146 MLIRContext *ctx = funcOp.getContext();
147 RewritePatternSet patterns(ctx);
148
149 //===--------------------------------------------------------------------===//
150 // Linalg distribution patterns.
151 //===--------------------------------------------------------------------===//
152 LinalgLoopDistributionOptions distributionOptions;
153
154 //===--------------------------------------------------------------------===//
155 // Linalg to vector contraction patterns.
156 //===--------------------------------------------------------------------===//
157 patterns.add<CopyVectorizationPattern>(arg&: ctx);
158
159 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
160}
161
162static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
163 RewritePatternSet forwardPattern(funcOp.getContext());
164 forwardPattern.add<LinalgCopyVTRForwardingPattern>(arg: funcOp.getContext());
165 forwardPattern.add<LinalgCopyVTWForwardingPattern>(arg: funcOp.getContext());
166 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(forwardPattern));
167}
168
169static void applyDecomposePadPatterns(func::FuncOp funcOp) {
170 RewritePatternSet patterns(funcOp.getContext());
171 patterns.add<DecomposePadOpPattern>(arg: funcOp.getContext());
172 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
173}
174
175static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
176 RewritePatternSet patterns(funcOp.getContext());
177 patterns.add<DecomposeOuterUnitDimsPackOpPattern>(arg: funcOp.getContext());
178 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
179}
180
181static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
182 RewritePatternSet patterns(funcOp.getContext());
183 patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(arg: funcOp.getContext());
184 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
185}
186
187static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
188 RewritePatternSet patterns(funcOp.getContext());
189 patterns.add<ExtractSliceOfPadTensorSwapPattern>(arg: funcOp.getContext());
190 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
191}
192
193static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
194 RewritePatternSet patterns(funcOp.getContext());
195 populateBubbleUpExtractSliceOpPatterns(patterns);
196 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
197}
198
199static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
200 RewritePatternSet patterns(funcOp.getContext());
201 populateSwapExtractSliceWithFillPatterns(patterns);
202 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
203}
204
205static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
206 RewritePatternSet patterns(funcOp.getContext());
207 populateEraseUnusedOperandsAndResultsPatterns(patterns);
208 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
209}
210
211static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
212 RewritePatternSet patterns(funcOp.getContext());
213 populateEraseUnnecessaryInputsPatterns(patterns);
214 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
215}
216
217static void applyWinogradConv2D(func::FuncOp funcOp) {
218 RewritePatternSet patterns(funcOp.getContext());
219 populateWinogradConv2DPatterns(patterns, fmr: WinogradConv2DFmr::F_4_3);
220 populateWinogradConv2DPatterns(patterns, fmr: WinogradConv2DFmr::F_2_5);
221 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
222}
223
224static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
225 RewritePatternSet patterns(funcOp.getContext());
226 populateDecomposeWinogradOpsPatterns(patterns);
227 (void)applyPatternsGreedily(op: funcOp, patterns: std::move(patterns));
228}
229
230static void applyFoldIntoPackAndUnpackPatterns(
231 Operation *rootOp,
232 linalg::ControlFoldIntoPackUnpackFn controlFn = nullptr) {
233 RewritePatternSet patterns(rootOp->getContext());
234 linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
235 (void)applyPatternsGreedily(op: rootOp, patterns: std::move(patterns));
236}
237
238static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
239 RewritePatternSet patterns(rootOp->getContext());
240 linalg::populateSimplifyPackAndUnpackPatterns(patterns);
241 (void)applyPatternsGreedily(op: rootOp, patterns: std::move(patterns));
242}
243
244/// Apply transformations specified as patterns.
245void TestLinalgTransforms::runOnOperation() {
246 if (testPatterns)
247 return applyPatterns(funcOp: getOperation());
248 if (testVectorTransferForwardingPatterns)
249 return applyVectorTransferForwardingPatterns(funcOp: getOperation());
250 if (testDecomposePadTensor)
251 return applyDecomposePadPatterns(funcOp: getOperation());
252 if (testDecomposeTensorPackOp)
253 return applyDecomposeTensorPackPatterns(funcOp: getOperation());
254 if (testDecomposeTensorUnPackOp)
255 return applyDecomposeTensorUnPackPatterns(funcOp: getOperation());
256 if (testSwapSubTensorPadTensor)
257 return applyExtractSliceOfPadTensorSwapPattern(funcOp: getOperation());
258 if (testBubbleUpExtractSliceOpPattern)
259 return applyBubbleUpExtractSliceOpPattern(funcOp: getOperation());
260 if (testSwapExtractSliceWithFill)
261 return applySwapExtractSliceWithFillPattern(funcOp: getOperation());
262 if (testEraseUnusedOperandsAndResults)
263 return applyEraseUnusedOperandsAndResultsPatterns(funcOp: getOperation());
264 if (testEraseUnnecessaryInputs)
265 return applyEraseUnnecessaryInputs(funcOp: getOperation());
266 if (testWinogradConv2D)
267 return applyWinogradConv2D(funcOp: getOperation());
268 if (testDecomposeWinogradOps)
269 return applyDecomposeWinogradOps(funcOp: getOperation());
270 Operation *rootOp = getOperation();
271 if (testFoldIntoPackAndUnpack)
272 applyFoldIntoPackAndUnpackPatterns(rootOp);
273 if (testFoldIntoPackAndUnpackWithControlFn) {
274 linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) {
275 Operation *producer = opOperand->get().getDefiningOp();
276 Operation *consumer = opOperand->getOwner();
277 // If we have a pack/unpack consumer and a producer that has multiple
278 // uses, do not apply the folding patterns.
279 if (isa<linalg::PackOp, linalg::UnPackOp>(Val: consumer) &&
280 isa<TilingInterface>(Val: producer) && !producer->hasOneUse())
281 return false;
282 return true;
283 };
284 applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
285 }
286 if (testSimplifyPackUnpackPatterns)
287 applySimplifyPackUnpackPatterns(rootOp);
288}
289
290namespace mlir {
291namespace test {
292void registerTestLinalgTransforms() {
293 PassRegistration<TestLinalgTransforms>();
294}
295} // namespace test
296} // namespace mlir
297

source code of mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp