//===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/IR/BuiltinOps.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/Interfaces/FunctionInterfaces.h"
13#include "mlir/Pass/Pass.h"
14
15using namespace mlir;
16
17namespace {
18/// This is a test pass for verifying matchers.
19struct TestMatchers
20 : public PassWrapper<TestMatchers, InterfacePass<FunctionOpInterface>> {
21 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchers)
22
23 void runOnOperation() override;
24 StringRef getArgument() const final { return "test-matchers"; }
25 StringRef getDescription() const final {
26 return "Test C++ pattern matchers.";
27 }
28};
29} // namespace
30
31// This could be done better but is not worth the variadic template trouble.
32template <typename Matcher>
33static unsigned countMatches(FunctionOpInterface f, Matcher &matcher) {
34 unsigned count = 0;
35 f.walk([&count, &matcher](Operation *op) {
36 if (matcher.match(op))
37 ++count;
38 });
39 return count;
40}
41
42using mlir::matchers::m_Any;
43using mlir::matchers::m_Val;
44static void test1(FunctionOpInterface f) {
45 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
46
47 auto a = m_Val(f.getArgument(0));
48 auto b = m_Val(f.getArgument(1));
49 auto c = m_Val(f.getArgument(2));
50
51 auto p0 = m_Op<arith::AddFOp>(); // using 0-arity matcher
52 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
53 << " times\n";
54
55 auto p1 = m_Op<arith::MulFOp>(); // using 0-arity matcher
56 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
57 << " times\n";
58
59 auto p2 = m_Op<arith::AddFOp>(m_Op<arith::AddFOp>(), m_Any());
60 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
61 << " times\n";
62
63 auto p3 = m_Op<arith::AddFOp>(m_Any(), m_Op<arith::AddFOp>());
64 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
65 << " times\n";
66
67 auto p4 = m_Op<arith::MulFOp>(m_Op<arith::AddFOp>(), m_Any());
68 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
69 << " times\n";
70
71 auto p5 = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
72 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
73 << " times\n";
74
75 auto p6 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Any());
76 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
77 << " times\n";
78
79 auto p7 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
80 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
81 << " times\n";
82
83 auto mulOfMulmul =
84 m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::MulFOp>());
85 auto p8 = m_Op<arith::MulFOp>(mulOfMulmul, mulOfMulmul);
86 llvm::outs()
87 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
88 << countMatches(f, p8) << " times\n";
89
90 // clang-format off
91 auto mulOfMuladd = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(), m_Op<arith::AddFOp>());
92 auto mulOfAnyadd = m_Op<arith::MulFOp>(m_Any(), m_Op<arith::AddFOp>());
93 auto p9 = m_Op<arith::MulFOp>(m_Op<arith::MulFOp>(
94 mulOfMuladd, m_Op<arith::MulFOp>()),
95 m_Op<arith::MulFOp>(mulOfAnyadd, mulOfAnyadd));
96 // clang-format on
97 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
98 "add(*)), mul(*, add(*)))) matched "
99 << countMatches(f, p9) << " times\n";
100
101 auto p10 = m_Op<arith::AddFOp>(a, b);
102 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
103 << " times\n";
104
105 auto p11 = m_Op<arith::AddFOp>(a, c);
106 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
107 << " times\n";
108
109 auto p12 = m_Op<arith::AddFOp>(b, a);
110 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
111 << " times\n";
112
113 auto p13 = m_Op<arith::AddFOp>(c, a);
114 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
115 << " times\n";
116
117 auto p14 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(c, b));
118 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
119 << " times\n";
120
121 auto p15 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(b, c));
122 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
123 << " times\n";
124
125 auto mulOfAany = m_Op<arith::MulFOp>(a, m_Any());
126 auto p16 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(a, c));
127 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
128 << countMatches(f, p16) << " times\n";
129
130 auto p17 = m_Op<arith::MulFOp>(mulOfAany, m_Op<arith::AddFOp>(c, b));
131 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
132 << countMatches(f, p17) << " times\n";
133}
134
135void test2(FunctionOpInterface f) {
136 auto a = m_Val(f.getArgument(0));
137 FloatAttr floatAttr;
138 auto p =
139 m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant(&floatAttr)));
140 auto p1 = m_Op<arith::MulFOp>(a, m_Op<arith::AddFOp>(a, m_Constant()));
141 // Last operation that is not the terminator.
142 Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
143 if (p.match(lastOp))
144 llvm::outs()
145 << "Pattern add(add(a, constant), a) matched and bound constant to: "
146 << floatAttr.getValueAsDouble() << "\n";
147 if (p1.match(lastOp))
148 llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
149}
150
151void test3(FunctionOpInterface f) {
152 arith::FastMathFlagsAttr fastMathAttr;
153 auto p = m_Op<arith::MulFOp>(m_Any(),
154 m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
155 auto p1 = m_Attr("fastmath", &fastMathAttr);
156
157 // Last operation that is not the terminator.
158 Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
159 if (p.match(lastOp))
160 llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
161 if (p1.match(lastOp))
162 llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
163 << fastMathAttr.getValue() << "\n";
164}
165
166void TestMatchers::runOnOperation() {
167 auto f = getOperation();
168 llvm::outs() << f.getName() << "\n";
169 if (f.getName() == "test1")
170 test1(f);
171 if (f.getName() == "test2")
172 test2(f);
173 if (f.getName() == "test3")
174 test3(f);
175}
176
177namespace mlir {
178void registerTestMatchers() { PassRegistration<TestMatchers>(); }
179} // namespace mlir
180

// source code of mlir/test/lib/IR/TestMatchers.cpp