1//===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/PatternMatch.h"
10#include "gtest/gtest.h"
11
12#include "../../test/lib/Dialect/Test/TestDialect.h"
13#include "../../test/lib/Dialect/Test/TestOps.h"
14
15using namespace mlir;
16
17namespace {
18struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
19 AnOpRewritePattern(MLIRContext *context)
20 : OpRewritePattern(context, /*benefit=*/1,
21 /*generatedNames=*/{test::OpB::getOperationName()}) {}
22
23 LogicalResult matchAndRewrite(test::OpA op,
24 PatternRewriter &rewriter) const override {
25 return failure();
26 }
27};
28TEST(OpRewritePatternTest, GetGeneratedNames) {
29 MLIRContext context;
30 AnOpRewritePattern pattern(&context);
31 ArrayRef<OperationName> ops = pattern.getGeneratedOps();
32
33 ASSERT_EQ(ops.size(), 1u);
34 ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
35}
36} // end anonymous namespace
37
38namespace {
39LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
40 return failure();
41}
42TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
43 MLIRContext context;
44 RewritePatternSet patterns(&context);
45
46 patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
47 /*generatedNames=*/{test::OpB::getOperationName()});
48 ASSERT_EQ(patterns.getNativePatterns().size(), 1U);
49 auto &pattern = patterns.getNativePatterns().front();
50 ASSERT_EQ(pattern->getBenefit(), 3);
51 ASSERT_EQ(pattern->getGeneratedOps().size(), 1U);
52 ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
53 test::OpB::getOperationName());
54}
55} // end anonymous namespace
56

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/unittests/IR/PatternMatchTest.cpp