1//===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/PatternMatch.h"
10#include "gtest/gtest.h"
11
12#include "../../test/lib/Dialect/Test/TestDialect.h"
13#include "../../test/lib/Dialect/Test/TestOps.h"
14
15using namespace mlir;
16
17namespace {
18struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
19 AnOpRewritePattern(MLIRContext *context)
20 : OpRewritePattern(context, /*benefit=*/1,
21 /*generatedNames=*/{test::OpB::getOperationName()}) {}
22};
23TEST(OpRewritePatternTest, GetGeneratedNames) {
24 MLIRContext context;
25 AnOpRewritePattern pattern(&context);
26 ArrayRef<OperationName> ops = pattern.getGeneratedOps();
27
28 ASSERT_EQ(ops.size(), 1u);
29 ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
30}
31} // end anonymous namespace
32
33namespace {
34LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
35 return failure();
36}
37TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
38 MLIRContext context;
39 RewritePatternSet patterns(&context);
40
41 patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
42 /*generatedNames=*/{test::OpB::getOperationName()});
43 ASSERT_EQ(patterns.getNativePatterns().size(), 1U);
44 auto &pattern = patterns.getNativePatterns().front();
45 ASSERT_EQ(pattern->getBenefit(), 3);
46 ASSERT_EQ(pattern->getGeneratedOps().size(), 1U);
47 ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
48 test::OpB::getOperationName());
49}
50} // end anonymous namespace
51

source code of mlir/unittests/IR/PatternMatchTest.cpp