1 | //===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/OwningOpRef.h" |
10 | #include "mlir/IR/PatternMatch.h" |
11 | #include "mlir/Rewrite/PatternApplicator.h" |
12 | #include "gtest/gtest.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace { |
17 | TEST(PatternBenefitTest, BenefitOrder) { |
18 | // There was a bug which caused low-benefit op-specific patterns to never be |
19 | // called in presence of high-benefit op-agnostic pattern |
20 | |
21 | MLIRContext context; |
22 | |
23 | OpBuilder builder(&context); |
24 | OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc()); |
25 | |
26 | struct Pattern1 : public OpRewritePattern<ModuleOp> { |
27 | Pattern1(mlir::MLIRContext *context, bool *called) |
28 | : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {} |
29 | |
30 | mlir::LogicalResult |
31 | matchAndRewrite(ModuleOp /*op*/, |
32 | mlir::PatternRewriter & /*rewriter*/) const override { |
33 | *called = true; |
34 | return failure(); |
35 | } |
36 | |
37 | private: |
38 | bool *called; |
39 | }; |
40 | |
41 | struct Pattern2 : public RewritePattern { |
42 | Pattern2(MLIRContext *context, bool *called) |
43 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context), |
44 | called(called) {} |
45 | |
46 | mlir::LogicalResult |
47 | matchAndRewrite(Operation * /*op*/, |
48 | mlir::PatternRewriter & /*rewriter*/) const override { |
49 | *called = true; |
50 | return failure(); |
51 | } |
52 | |
53 | private: |
54 | bool *called; |
55 | }; |
56 | |
57 | RewritePatternSet patterns(&context); |
58 | |
59 | bool called1 = false; |
60 | bool called2 = false; |
61 | |
62 | patterns.add<Pattern1>(arg: &context, args: &called1); |
63 | patterns.add<Pattern2>(arg: &context, args: &called2); |
64 | |
65 | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
66 | PatternApplicator pa(frozenPatterns); |
67 | pa.applyDefaultCostModel(); |
68 | |
69 | class MyPatternRewriter : public PatternRewriter { |
70 | public: |
71 | MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {} |
72 | }; |
73 | |
74 | MyPatternRewriter rewriter(&context); |
75 | (void)pa.matchAndRewrite(op: *module, rewriter); |
76 | |
77 | EXPECT_TRUE(called1); |
78 | EXPECT_TRUE(called2); |
79 | } |
80 | } // namespace |
81 | |