1//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly the options that control the Linalg fusion patterns.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Pass/PassManager.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22using namespace mlir;
23
24static void addOperands(Operation *op, SetVector<Value> &operandSet) {
25 if (!op)
26 return;
27 TypeSwitch<Operation *, void>(op)
28 .Case<linalg::LinalgOp>(caseFn: [&](linalg::LinalgOp linalgOp) {
29 SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
30 operandSet.insert(Start: inputOperands.begin(), End: inputOperands.end());
31 })
32 .Default(defaultFn: [&](Operation *operation) {
33 operandSet.insert(Start: operation->operand_begin(), End: operation->operand_end());
34 });
35}
36
37template <int limit = 3>
38static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
39 Operation *producer = fusedOperand->get().getDefiningOp();
40 if (!producer)
41 return false;
42
43 Operation *consumer = fusedOperand->getOwner();
44 SetVector<Value> fusedOpOperands;
45 if (producer->getNumResults() != 1)
46 return false;
47 addOperands(op: consumer, operandSet&: fusedOpOperands);
48 fusedOpOperands.remove(X: producer->getResult(idx: 0));
49 addOperands(op: producer, operandSet&: fusedOpOperands);
50 return fusedOpOperands.size() <= limit;
51}
52
53namespace {
54
55/// Pattern to test fusion of producer with consumer, even if producer has
56/// multiple uses.
57struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
58 using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
59
60 LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
61 PatternRewriter &rewriter) const override {
62 OpOperand *fusableOperand = nullptr;
63 for (OpOperand &operand : genericOp->getOpOperands()) {
64 if (linalg::areElementwiseOpsFusable(&operand)) {
65 fusableOperand = &operand;
66 break;
67 }
68 }
69 if (!fusableOperand) {
70 return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
71 }
72 std::optional<linalg::ElementwiseOpFusionResult> fusionResult =
73 linalg::fuseElementwiseOps(rewriter, fusedOperand: fusableOperand);
74 if (!fusionResult)
75 return rewriter.notifyMatchFailure(genericOp, "fusion failed");
76 for (auto [origValue, replacement] : fusionResult->replacements) {
77 rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) {
78 return use.getOwner() != genericOp.getOperation();
79 });
80 }
81 rewriter.eraseOp(op: genericOp);
82 return success();
83 }
84};
85
86struct TestLinalgElementwiseFusion
87 : public PassWrapper<TestLinalgElementwiseFusion,
88 OperationPass<func::FuncOp>> {
89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
90
91 TestLinalgElementwiseFusion() = default;
92 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
93 : PassWrapper(pass) {}
94 void getDependentDialects(DialectRegistry &registry) const override {
95 registry.insert<affine::AffineDialect, linalg::LinalgDialect,
96 memref::MemRefDialect, tensor::TensorDialect>();
97 }
98 StringRef getArgument() const final {
99 return "test-linalg-elementwise-fusion-patterns";
100 }
101 StringRef getDescription() const final {
102 return "Test Linalg element wise operation fusion patterns";
103 }
104
105 Option<bool> fuseGenericOps{
106 *this, "fuse-generic-ops",
107 llvm::cl::desc("Test fusion of generic operations."),
108 llvm::cl::init(Val: false)};
109
110 Option<bool> fuseGenericOpsControl{
111 *this, "fuse-generic-ops-control",
112 llvm::cl::desc(
113 "Test fusion of generic operations with a control function."),
114 llvm::cl::init(Val: false)};
115
116 Option<bool> fuseWithReshapeByExpansion{
117 *this, "fuse-with-reshape-by-expansion",
118 llvm::cl::desc(
119 "Test fusion of generic operations with reshape by expansion"),
120 llvm::cl::init(Val: false)};
121
122 Option<bool> controlFuseByExpansion{
123 *this, "control-fusion-by-expansion",
124 llvm::cl::desc(
125 "Test controlling fusion of reshape with generic op by expansion"),
126 llvm::cl::init(Val: false)};
127
128 Option<bool> fuseWithReshapeByCollapsing{
129 *this, "fuse-with-reshape-by-collapsing",
130 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
131 "collapse the iteration space of the consumer"),
132 llvm::cl::init(Val: false)};
133
134 Option<bool> fuseWithReshapeByCollapsingWithControlFn{
135 *this, "fuse-with-reshape-by-collapsing-control",
136 llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
137 "fusion patterns that "
138 "collapse the iteration space of the consumer"),
139 llvm::cl::init(Val: false)};
140
141 Option<bool> fuseMultiUseProducer{
142 *this, "fuse-multiuse-producer",
143 llvm::cl::desc("Test fusion of producer ops with multiple uses"),
144 llvm::cl::init(Val: false)};
145
146 ListOption<int64_t> collapseDimensions{
147 *this, "collapse-dimensions-control",
148 llvm::cl::desc("Test controlling dimension collapse pattern")};
149
150 void runOnOperation() override {
151 MLIRContext *context = &this->getContext();
152 func::FuncOp funcOp = this->getOperation();
153
154 if (fuseGenericOps) {
155 RewritePatternSet fusionPatterns(context);
156 auto controlFn = [](OpOperand *operand) { return true; };
157 linalg::populateElementwiseOpsFusionPatterns(patterns&: fusionPatterns, controlElementwiseOpFusion: controlFn);
158 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
159 std::move(fusionPatterns))))
160 return signalPassFailure();
161 return;
162 }
163
164 if (fuseGenericOpsControl) {
165 RewritePatternSet fusionPatterns(context);
166 linalg::populateElementwiseOpsFusionPatterns(patterns&: fusionPatterns,
167 controlElementwiseOpFusion: setFusedOpOperandLimit<4>);
168
169 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
170 std::move(fusionPatterns))))
171 return signalPassFailure();
172 return;
173 }
174
175 if (fuseWithReshapeByExpansion) {
176 RewritePatternSet fusionPatterns(context);
177 linalg::populateFoldReshapeOpsByExpansionPatterns(
178 patterns&: fusionPatterns, controlFoldingReshapes: [](OpOperand * /*fusedOperand*/) { return true; });
179 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
180 std::move(fusionPatterns))))
181 return signalPassFailure();
182 return;
183 }
184
185 if (controlFuseByExpansion) {
186 RewritePatternSet fusionPatterns(context);
187
188 linalg::ControlFusionFn controlReshapeFusionFn =
189 [](OpOperand *fusedOperand) {
190 Operation *producer = fusedOperand->get().getDefiningOp();
191 if (!producer)
192 return false;
193
194 if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
195 if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
196 return false;
197 }
198 }
199
200 Operation *consumer = fusedOperand->getOwner();
201 if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
202 if (expandOp->hasOneUse()) {
203 OpOperand &use = *expandOp->getUses().begin();
204 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
205 if (linalgOp && linalgOp.isDpsInit(&use))
206 return true;
207 }
208 return false;
209 }
210 return true;
211 };
212
213 linalg::populateFoldReshapeOpsByExpansionPatterns(patterns&: fusionPatterns,
214 controlFoldingReshapes: controlReshapeFusionFn);
215 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
216 std::move(fusionPatterns))))
217 return signalPassFailure();
218 return;
219 }
220
221 if (fuseWithReshapeByCollapsing) {
222 RewritePatternSet patterns(context);
223 linalg::populateFoldReshapeOpsByCollapsingPatterns(
224 patterns, controlFoldingReshapes: [](OpOperand * /*fusedOperand */) { return true; });
225 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
226 std::move(patterns))))
227 return signalPassFailure();
228 return;
229 }
230
231 if (fuseWithReshapeByCollapsingWithControlFn) {
232 RewritePatternSet patterns(context);
233 linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
234 Operation *producer = fusedOperand->get().getDefiningOp();
235 if (isa<tensor::ExpandShapeOp>(producer)) {
236 // Skip fusing the first operand.
237 return fusedOperand->getOperandNumber();
238 }
239 return true;
240 };
241 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFoldingReshapes: controlFn);
242 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
243 std::move(patterns))))
244 return signalPassFailure();
245 return;
246 }
247
248 if (fuseMultiUseProducer) {
249 RewritePatternSet patterns(context);
250 patterns.insert<TestMultiUseProducerFusion>(arg&: context);
251 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
252 std::move(patterns))))
253 return signalPassFailure();
254 return;
255 }
256
257 if (!collapseDimensions.empty()) {
258 SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
259 collapseDimensions.end());
260 linalg::GetCollapsableDimensionsFn collapseFn =
261 [&dims](linalg::LinalgOp op) {
262 SmallVector<ReassociationIndices> reassociations;
263 reassociations.emplace_back(Args&: dims);
264 return reassociations;
265 };
266 RewritePatternSet patterns(context);
267 linalg::populateCollapseDimensions(patterns, controlCollapseDimensions: collapseFn);
268 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
269 std::move(patterns))))
270 return signalPassFailure();
271 return;
272 }
273 }
274};
275
276} // namespace
277
278namespace mlir {
279namespace test {
280void registerTestLinalgElementwiseFusion() {
281 PassRegistration<TestLinalgElementwiseFusion>();
282}
283} // namespace test
284} // namespace mlir
285

source code of mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp