//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "gtest/gtest.h"

#include <utility>

using namespace mlir;

16namespace {
17TEST(PatternBenefitTest, BenefitOrder) {
18 // There was a bug which caused low-benefit op-specific patterns to never be
19 // called in presence of high-benefit op-agnostic pattern
20
21 MLIRContext context;
22
23 OpBuilder builder(&context);
24 OwningOpRef<ModuleOp> module = ModuleOp::create(builder.getUnknownLoc());
25
26 struct Pattern1 : public OpRewritePattern<ModuleOp> {
27 Pattern1(mlir::MLIRContext *context, bool *called)
28 : OpRewritePattern<ModuleOp>(context, /*benefit*/ 1), called(called) {}
29
30 mlir::LogicalResult
31 matchAndRewrite(ModuleOp /*op*/,
32 mlir::PatternRewriter & /*rewriter*/) const override {
33 *called = true;
34 return failure();
35 }
36
37 private:
38 bool *called;
39 };
40
41 struct Pattern2 : public RewritePattern {
42 Pattern2(MLIRContext *context, bool *called)
43 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context),
44 called(called) {}
45
46 mlir::LogicalResult
47 matchAndRewrite(Operation * /*op*/,
48 mlir::PatternRewriter & /*rewriter*/) const override {
49 *called = true;
50 return failure();
51 }
52
53 private:
54 bool *called;
55 };
56
57 RewritePatternSet patterns(&context);
58
59 bool called1 = false;
60 bool called2 = false;
61
62 patterns.add<Pattern1>(arg: &context, args: &called1);
63 patterns.add<Pattern2>(arg: &context, args: &called2);
64
65 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
66 PatternApplicator pa(frozenPatterns);
67 pa.applyDefaultCostModel();
68
69 class MyPatternRewriter : public PatternRewriter {
70 public:
71 MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
72 };
73
74 MyPatternRewriter rewriter(&context);
75 (void)pa.matchAndRewrite(op: *module, rewriter);
76
77 EXPECT_TRUE(called1);
78 EXPECT_TRUE(called2);
79}
80} // namespace
// Source: mlir/unittests/Rewrite/PatternBenefit.cpp