1//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Rewrite/FrozenRewritePatternSet.h"
10#include "ByteCode.h"
11#include "mlir/Interfaces/SideEffectInterfaces.h"
12#include "mlir/Pass/Pass.h"
13#include "mlir/Pass/PassManager.h"
14#include <optional>
15
16using namespace mlir;
17
18// Include the PDL rewrite support.
19#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
21#include "mlir/Dialect/PDL/IR/PDLOps.h"
22
23static LogicalResult
24convertPDLToPDLInterp(ModuleOp pdlModule,
25 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
26 // Skip the conversion if the module doesn't contain pdl.
27 if (pdlModule.getOps<pdl::PatternOp>().empty())
28 return success();
29
30 // Simplify the provided PDL module. Note that we can't use the canonicalizer
31 // here because it would create a cyclic dependency.
32 auto simplifyFn = [](Operation *op) {
33 // TODO: Add folding here if ever necessary.
34 if (isOpTriviallyDead(op))
35 op->erase();
36 };
37 pdlModule.getBody()->walk(simplifyFn);
38
39 /// Lower the PDL pattern module to the interpreter dialect.
40 PassManager pdlPipeline(pdlModule->getName());
41#ifdef NDEBUG
42 // We don't want to incur the hit of running the verifier when in release
43 // mode.
44 pdlPipeline.enableVerifier(false);
45#endif
46 pdlPipeline.addPass(createConvertPDLToPDLInterpPass(configMap));
47 if (failed(pdlPipeline.run(op: pdlModule)))
48 return failure();
49
50 // Simplify again after running the lowering pipeline.
51 pdlModule.getBody()->walk(simplifyFn);
52 return success();
53}
54#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
55
56//===----------------------------------------------------------------------===//
57// FrozenRewritePatternSet
58//===----------------------------------------------------------------------===//
59
60FrozenRewritePatternSet::FrozenRewritePatternSet()
61 : impl(std::make_shared<Impl>()) {}
62
63FrozenRewritePatternSet::FrozenRewritePatternSet(
64 RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
65 ArrayRef<std::string> enabledPatternLabels)
66 : impl(std::make_shared<Impl>()) {
67 DenseSet<StringRef> disabledPatterns, enabledPatterns;
68 disabledPatterns.insert_range(R&: disabledPatternLabels);
69 enabledPatterns.insert_range(R&: enabledPatternLabels);
70
71 // Functor used to walk all of the operations registered in the context. This
72 // is useful for patterns that get applied to multiple operations, such as
73 // interface and trait based patterns.
74 std::vector<RegisteredOperationName> opInfos;
75 auto addToOpsWhen =
76 [&](std::unique_ptr<RewritePattern> &pattern,
77 function_ref<bool(RegisteredOperationName)> callbackFn) {
78 if (opInfos.empty())
79 opInfos = pattern->getContext()->getRegisteredOperations();
80 for (RegisteredOperationName info : opInfos)
81 if (callbackFn(info))
82 impl->nativeOpSpecificPatternMap[info].push_back(x: pattern.get());
83 impl->nativeOpSpecificPatternList.push_back(x: std::move(pattern));
84 };
85
86 for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
87 // Don't add patterns that haven't been enabled by the user.
88 if (!enabledPatterns.empty()) {
89 auto isEnabledFn = [&](StringRef label) {
90 return enabledPatterns.count(V: label);
91 };
92 if (!isEnabledFn(pat->getDebugName()) &&
93 llvm::none_of(Range: pat->getDebugLabels(), P: isEnabledFn))
94 continue;
95 }
96 // Don't add patterns that have been disabled by the user.
97 if (!disabledPatterns.empty()) {
98 auto isDisabledFn = [&](StringRef label) {
99 return disabledPatterns.count(V: label);
100 };
101 if (isDisabledFn(pat->getDebugName()) ||
102 llvm::any_of(Range: pat->getDebugLabels(), P: isDisabledFn))
103 continue;
104 }
105
106 if (std::optional<OperationName> rootName = pat->getRootKind()) {
107 impl->nativeOpSpecificPatternMap[*rootName].push_back(x: pat.get());
108 impl->nativeOpSpecificPatternList.push_back(x: std::move(pat));
109 continue;
110 }
111 if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
112 addToOpsWhen(pat, [&](RegisteredOperationName info) {
113 return info.hasInterface(interfaceID: *interfaceID);
114 });
115 continue;
116 }
117 if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
118 addToOpsWhen(pat, [&](RegisteredOperationName info) {
119 return info.hasTrait(traitID: *traitID);
120 });
121 continue;
122 }
123 impl->nativeAnyOpPatterns.push_back(x: std::move(pat));
124 }
125
126#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
127 // Generate the bytecode for the PDL patterns if any were provided.
128 PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
129 ModuleOp pdlModule = pdlPatterns.getModule();
130 if (!pdlModule)
131 return;
132 DenseMap<Operation *, PDLPatternConfigSet *> configMap =
133 pdlPatterns.takeConfigMap();
134 if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
135 llvm::report_fatal_error(
136 reason: "failed to lower PDL pattern module to the PDL Interpreter");
137
138 // Generate the pdl bytecode.
139 impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
140 pdlModule, pdlPatterns.takeConfigs(), configMap,
141 pdlPatterns.takeConstraintFunctions(),
142 pdlPatterns.takeRewriteFunctions());
143#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
144}
145
146FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
147

source code of mlir/lib/Rewrite/FrozenRewritePatternSet.cpp