| 1 | //===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/IR/PatternMatch.h" |
| 10 | #include "gtest/gtest.h" |
| 11 | |
| 12 | #include "../../test/lib/Dialect/Test/TestDialect.h" |
| 13 | #include "../../test/lib/Dialect/Test/TestOps.h" |
| 14 | |
| 15 | using namespace mlir; |
| 16 | |
| 17 | namespace { |
| 18 | struct AnOpRewritePattern : OpRewritePattern<test::OpA> { |
| 19 | AnOpRewritePattern(MLIRContext *context) |
| 20 | : OpRewritePattern(context, /*benefit=*/1, |
| 21 | /*generatedNames=*/{test::OpB::getOperationName()}) {} |
| 22 | |
| 23 | LogicalResult matchAndRewrite(test::OpA op, |
| 24 | PatternRewriter &rewriter) const override { |
| 25 | return failure(); |
| 26 | } |
| 27 | }; |
| 28 | TEST(OpRewritePatternTest, GetGeneratedNames) { |
| 29 | MLIRContext context; |
| 30 | AnOpRewritePattern pattern(&context); |
| 31 | ArrayRef<OperationName> ops = pattern.getGeneratedOps(); |
| 32 | |
| 33 | ASSERT_EQ(ops.size(), 1u); |
| 34 | ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName()); |
| 35 | } |
| 36 | } // end anonymous namespace |
| 37 | |
| 38 | namespace { |
| 39 | LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) { |
| 40 | return failure(); |
| 41 | } |
| 42 | TEST(AnOpRewritePatternTest, PatternFuncAttributes) { |
| 43 | MLIRContext context; |
| 44 | RewritePatternSet patterns(&context); |
| 45 | |
| 46 | patterns.add(anOpRewritePatternFunc, /*benefit=*/3, |
| 47 | /*generatedNames=*/{test::OpB::getOperationName()}); |
| 48 | ASSERT_EQ(patterns.getNativePatterns().size(), 1U); |
| 49 | auto &pattern = patterns.getNativePatterns().front(); |
| 50 | ASSERT_EQ(pattern->getBenefit(), 3); |
| 51 | ASSERT_EQ(pattern->getGeneratedOps().size(), 1U); |
| 52 | ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(), |
| 53 | test::OpB::getOperationName()); |
| 54 | } |
| 55 | } // end anonymous namespace |
| 56 | |