1//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic for testing Linalg transformations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Passes.h"
20#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22#include "mlir/Dialect/Linalg/Utils/Utils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Pass/PassManager.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27#include "llvm/ADT/SmallVector.h"
28
29using namespace mlir;
30using namespace mlir::linalg;
31
32namespace {
33struct TestLinalgTransforms
34 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
36
37 TestLinalgTransforms() = default;
38 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
39
40 void getDependentDialects(DialectRegistry &registry) const override {
41 // clang-format off
42 registry.insert<affine::AffineDialect,
43 bufferization::BufferizationDialect,
44 memref::MemRefDialect,
45 scf::SCFDialect,
46 linalg::LinalgDialect,
47 vector::VectorDialect,
48 gpu::GPUDialect>();
49 // clang-format on
50 }
51 StringRef getArgument() const final {
52 return "test-linalg-transform-patterns";
53 }
54 StringRef getDescription() const final {
55 return "Test Linalg transformation patterns by applying them greedily.";
56 }
57
58 void runOnOperation() override;
59
60 Option<bool> testPatterns{*this, "test-patterns",
61 llvm::cl::desc("Test a mixed set of patterns"),
62 llvm::cl::init(Val: false)};
63 Option<bool> testVectorTransferForwardingPatterns{
64 *this, "test-vector-transfer-forwarding-patterns",
65 llvm::cl::desc(
66 "Test a fused pass that forwards memref.copy to vector.transfer"),
67 llvm::cl::init(Val: false)};
68 Option<bool> testGenericToVectorPattern{
69 *this, "test-linalg-to-vector-patterns",
70 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
71 "in vector.contract form"),
72 llvm::cl::init(Val: false)};
73 Option<bool> testGeneralizePadTensor{
74 *this, "test-generalize-pad-tensor",
75 llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
76 llvm::cl::init(Val: false)};
77 Option<bool> testGeneralizeTensorPackOp{
78 *this, "test-generalize-tensor-pack",
79 llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
80 "of tensor and Linalg ops"),
81 llvm::cl::init(Val: false)};
82 Option<bool> testGeneralizeTensorUnPackOp{
83 *this, "test-generalize-tensor-unpack",
84 llvm::cl::desc(
85 "Test transform that generalizes unpack ops into a sequence "
86 "of tensor and Linalg ops"),
87 llvm::cl::init(Val: false)};
88 Option<bool> testSwapSubTensorPadTensor{
89 *this, "test-swap-subtensor-padtensor",
90 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
91 "tensor.pad(subtensor)"),
92 llvm::cl::init(Val: false)};
93 ListOption<int64_t> peeledLoops{
94 *this, "peeled-loops",
95 llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
96 ListOption<int64_t> tileSizes{
97 *this, "tile-sizes",
98 llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
99 Option<bool> skipPartial{
100 *this, "skip-partial",
101 llvm::cl::desc("Skip loops inside partial iterations during peeling"),
102 llvm::cl::init(Val: false)};
103 Option<std::string> loopType{
104 *this, "loop-type",
105 llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
106 "tiled_loop"),
107 llvm::cl::init(Val: "for")};
108 Option<bool> testBubbleUpExtractSliceOpPattern{
109 *this, "test-bubble-up-extract-slice-op-pattern",
110 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
111 "extract_slice + linalgOp"),
112 llvm::cl::init(Val: false)};
113 Option<bool> testSwapExtractSliceWithFill{
114 *this, "test-swap-extract-slice-with-fill-pattern",
115 llvm::cl::desc(
116 "Test patterns to swap tensor.extract_slice(linalg.fill())"),
117 llvm::cl::init(Val: false)};
118 Option<bool> testEraseUnusedOperandsAndResults{
119 *this, "test-erase-unused-operands-and-results",
120 llvm::cl::desc("Test patterns to erase unused operands and results"),
121 llvm::cl::init(Val: false)};
122 Option<bool> testEraseUnnecessaryInputs{
123 *this, "test-erase-unnecessary-inputs",
124 llvm::cl::desc("Test patterns to erase unnecessary inputs"),
125 llvm::cl::init(Val: false)};
126};
127} // namespace
128
129static void applyPatterns(func::FuncOp funcOp) {
130 MLIRContext *ctx = funcOp.getContext();
131 RewritePatternSet patterns(ctx);
132
133 //===--------------------------------------------------------------------===//
134 // Linalg distribution patterns.
135 //===--------------------------------------------------------------------===//
136 LinalgLoopDistributionOptions distributionOptions;
137
138 //===--------------------------------------------------------------------===//
139 // Linalg to vector contraction patterns.
140 //===--------------------------------------------------------------------===//
141 patterns.add<CopyVectorizationPattern>(arg&: ctx);
142
143 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
144}
145
146static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
147 RewritePatternSet forwardPattern(funcOp.getContext());
148 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
149 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
150 (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
151}
152
153static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
154 RewritePatternSet patterns(funcOp.getContext());
155 auto *ctx = funcOp.getContext();
156 patterns.add<CopyVectorizationPattern>(ctx);
157 populatePadOpVectorizationPatterns(patterns);
158 populateConvolutionVectorizationPatterns(patterns);
159 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
160}
161
162static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
163 RewritePatternSet patterns(funcOp.getContext());
164 patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
165 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
166}
167
168static void applyGeneralizeTensorPackPatterns(func::FuncOp funcOp) {
169 RewritePatternSet patterns(funcOp.getContext());
170 patterns.add<GeneralizeOuterUnitDimsPackOpPattern>(funcOp.getContext());
171 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
172}
173
174static void applyGeneralizeTensorUnPackPatterns(func::FuncOp funcOp) {
175 RewritePatternSet patterns(funcOp.getContext());
176 patterns.add<GeneralizeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
177 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
178}
179
180static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
181 RewritePatternSet patterns(funcOp.getContext());
182 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
183 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
184}
185
186static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
187 RewritePatternSet patterns(funcOp.getContext());
188 populateBubbleUpExtractSliceOpPatterns(patterns);
189 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
190}
191
192static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
193 RewritePatternSet patterns(funcOp.getContext());
194 populateSwapExtractSliceWithFillPatterns(patterns);
195 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
196}
197
198static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
199 RewritePatternSet patterns(funcOp.getContext());
200 populateEraseUnusedOperandsAndResultsPatterns(patterns);
201 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
202}
203
204static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
205 RewritePatternSet patterns(funcOp.getContext());
206 populateEraseUnnecessaryInputsPatterns(patterns);
207 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
208}
209
210/// Apply transformations specified as patterns.
211void TestLinalgTransforms::runOnOperation() {
212 if (testPatterns)
213 return applyPatterns(getOperation());
214 if (testVectorTransferForwardingPatterns)
215 return applyVectorTransferForwardingPatterns(getOperation());
216 if (testGenericToVectorPattern)
217 return applyLinalgToVectorPatterns(getOperation());
218 if (testGeneralizePadTensor)
219 return applyGeneralizePadTensorPatterns(getOperation());
220 if (testGeneralizeTensorPackOp)
221 return applyGeneralizeTensorPackPatterns(getOperation());
222 if (testGeneralizeTensorUnPackOp)
223 return applyGeneralizeTensorUnPackPatterns(getOperation());
224 if (testSwapSubTensorPadTensor)
225 return applyExtractSliceOfPadTensorSwapPattern(getOperation());
226 if (testBubbleUpExtractSliceOpPattern)
227 return applyBubbleUpExtractSliceOpPattern(getOperation());
228 if (testSwapExtractSliceWithFill)
229 return applySwapExtractSliceWithFillPattern(getOperation());
230 if (testEraseUnusedOperandsAndResults)
231 return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
232 if (testEraseUnnecessaryInputs)
233 return applyEraseUnnecessaryInputs(getOperation());
234}
235
236namespace mlir {
237namespace test {
238void registerTestLinalgTransforms() {
239 PassRegistration<TestLinalgTransforms>();
240}
241} // namespace test
242} // namespace mlir
243

// Source: mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp