//===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/PatternMatch.h"
10#include "mlir/Parser/Parser.h"
11#include "mlir/Pass/PassManager.h"
12#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13#include "mlir/Transforms/Passes.h"
14#include "gtest/gtest.h"
15
16using namespace mlir;
17
18namespace {
19
20struct DisabledPattern : public RewritePattern {
21 DisabledPattern(MLIRContext *context)
22 : RewritePattern("test.foo", /*benefit=*/0, context,
23 /*generatedNamed=*/{}) {
24 setDebugName("DisabledPattern");
25 }
26
27 LogicalResult matchAndRewrite(Operation *op,
28 PatternRewriter &rewriter) const override {
29 if (op->getNumResults() != 1)
30 return failure();
31 rewriter.eraseOp(op);
32 return success();
33 }
34};
35
36struct EnabledPattern : public RewritePattern {
37 EnabledPattern(MLIRContext *context)
38 : RewritePattern("test.foo", /*benefit=*/0, context,
39 /*generatedNamed=*/{}) {
40 setDebugName("EnabledPattern");
41 }
42
43 LogicalResult matchAndRewrite(Operation *op,
44 PatternRewriter &rewriter) const override {
45 if (op->getNumResults() == 1)
46 return failure();
47 rewriter.eraseOp(op);
48 return success();
49 }
50};
51
52struct TestDialect : public Dialect {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
54
55 static StringRef getDialectNamespace() { return "test"; }
56
57 TestDialect(MLIRContext *context)
58 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
59 allowUnknownOperations();
60 }
61
62 void getCanonicalizationPatterns(RewritePatternSet &results) const override {
63 results.add<DisabledPattern, EnabledPattern>(arg: results.getContext());
64 }
65};
66
67TEST(CanonicalizerTest, TestDisablePatterns) {
68 MLIRContext context;
69 context.getOrLoadDialect<TestDialect>();
70 PassManager mgr(&context);
71 mgr.addPass(
72 pass: createCanonicalizerPass(config: GreedyRewriteConfig(), disabledPatterns: {"DisabledPattern"}));
73
74 const char *const code = R"mlir(
75 %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
76 %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
77 )mlir";
78
79 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(sourceStr: code, config: &context);
80 ASSERT_TRUE(succeeded(mgr.run(*module)));
81
82 EXPECT_TRUE(module->lookupSymbol("B"));
83 EXPECT_FALSE(module->lookupSymbol("A"));
84}
85
86} // end anonymous namespace
87

// source code of mlir/unittests/Transforms/Canonicalizer.cpp