1//===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic for testing Linalg fusion patterns.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16#include "mlir/Dialect/SCF/Transforms/Patterns.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Pass/PassManager.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20#include "mlir/Transforms/Passes.h"
21
22using namespace mlir;
23using namespace mlir::linalg;
24
25static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
26 OpBuilder b(f);
27 DenseSet<Operation *> eraseSet;
28
29 // Save original Linalg ops, we only want to make a pass over those.
30 SmallVector<LinalgOp, 8> linalgOps;
31 f.walk([&](LinalgOp op) {
32 // TODO: support multi-results.
33 if (op->getNumResults() <= 1)
34 linalgOps.push_back(op);
35 });
36
37 // Tile and Fuse for tensors inputs (TODO: all tensor operands).
38 bool changed = false;
39 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
40 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
41 if (isa<MemRefType>(opOperand.get().getType()))
42 continue;
43 if (isa<RankedTensorType>(opOperand.get().getType())) {
44 // Tile and Fuse tensor input.
45 if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
46 continue;
47 auto info = fuseProducerOfTensor(b, opOperand);
48 if (failed(info))
49 continue;
50 auto *originalOp = info->originalProducer.getOperation();
51 auto *originalOpInLinalgOpsVector =
52 std::find(linalgOps.begin(), linalgOps.end(), originalOp);
53 *originalOpInLinalgOpsVector = info->fusedProducer;
54 // Don't mark for erasure in the tensor case, let DCE handle this.
55 changed = true;
56 }
57 }
58 }
59 // The `fuseProducerOfBuffer` function performs structural checks and in
60 // particular that no covering read or write exist between the consumer and
61 // the producer. As a consequence, the only fusions that may occur preserve
62 // subsequent dependences and are guaranteed by construction to produce the
63 // whole view. We may thus erase the producer once it is fused.
64 for (auto *e : eraseSet)
65 e->erase();
66
67 return changed ? success() : failure();
68}
69
70namespace {
71struct TestLinalgGreedyFusion
72 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> {
73 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion)
74
75 void getDependentDialects(DialectRegistry &registry) const override {
76 registry.insert<affine::AffineDialect, linalg::LinalgDialect,
77 memref::MemRefDialect, scf::SCFDialect>();
78 }
79 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
80 StringRef getDescription() const final {
81 return "Test Linalg fusion by applying a greedy test transformation.";
82 }
83 void runOnOperation() override {
84 MLIRContext *context = &getContext();
85 RewritePatternSet patterns =
86 linalg::getLinalgTilingCanonicalizationPatterns(ctx: context);
87 patterns.add<ExtractSliceOfPadTensorSwapPattern>(arg&: context);
88 scf::populateSCFForLoopCanonicalizationPatterns(patterns);
89 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
90 OpPassManager pm(func::FuncOp::getOperationName());
91 pm.addPass(createLoopInvariantCodeMotionPass());
92 pm.addPass(createCanonicalizerPass());
93 pm.addPass(createCSEPass());
94 do {
95 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
96 if (failed(runPipeline(pipeline&: pm, op: getOperation())))
97 this->signalPassFailure();
98 } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
99 }
100};
101} // namespace
102
103namespace mlir {
104namespace test {
105void registerTestLinalgGreedyFusion() {
106 PassRegistration<TestLinalgGreedyFusion>();
107}
108
109} // namespace test
110} // namespace mlir
111

source code of mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp