//===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/PatternMatch.h"
10#include "gtest/gtest.h"
11
12#include "../../test/lib/Dialect/Test/TestDialect.h"
13
14using namespace mlir;
15
16namespace {
17struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
18 AnOpRewritePattern(MLIRContext *context)
19 : OpRewritePattern(context, /*benefit=*/1,
20 /*generatedNames=*/{test::OpB::getOperationName()}) {}
21};
22TEST(OpRewritePatternTest, GetGeneratedNames) {
23 MLIRContext context;
24 AnOpRewritePattern pattern(&context);
25 ArrayRef<OperationName> ops = pattern.getGeneratedOps();
26
27 ASSERT_EQ(ops.size(), 1u);
28 ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
29}
30} // end anonymous namespace
31
32namespace {
33LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
34 return failure();
35}
36TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
37 MLIRContext context;
38 RewritePatternSet patterns(&context);
39
40 patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
41 /*generatedNames=*/{test::OpB::getOperationName()});
42 ASSERT_EQ(patterns.getNativePatterns().size(), 1U);
43 auto &pattern = patterns.getNativePatterns().front();
44 ASSERT_EQ(pattern->getBenefit(), 3);
45 ASSERT_EQ(pattern->getGeneratedOps().size(), 1U);
46 ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
47 test::OpB::getOperationName());
48}
49} // end anonymous namespace
50

// Source: mlir/unittests/IR/PatternMatchTest.cpp