1 | //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
10 | #include "ByteCode.h" |
11 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
12 | #include "mlir/Pass/Pass.h" |
13 | #include "mlir/Pass/PassManager.h" |
14 | #include <optional> |
15 | |
16 | using namespace mlir; |
17 | |
18 | // Include the PDL rewrite support. |
19 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
20 | #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" |
21 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
22 | |
23 | static LogicalResult |
24 | convertPDLToPDLInterp(ModuleOp pdlModule, |
25 | DenseMap<Operation *, PDLPatternConfigSet *> &configMap) { |
26 | // Skip the conversion if the module doesn't contain pdl. |
27 | if (pdlModule.getOps<pdl::PatternOp>().empty()) |
28 | return success(); |
29 | |
30 | // Simplify the provided PDL module. Note that we can't use the canonicalizer |
31 | // here because it would create a cyclic dependency. |
32 | auto simplifyFn = [](Operation *op) { |
33 | // TODO: Add folding here if ever necessary. |
34 | if (isOpTriviallyDead(op)) |
35 | op->erase(); |
36 | }; |
37 | pdlModule.getBody()->walk(simplifyFn); |
38 | |
39 | /// Lower the PDL pattern module to the interpreter dialect. |
40 | PassManager pdlPipeline(pdlModule->getName()); |
41 | #ifdef NDEBUG |
42 | // We don't want to incur the hit of running the verifier when in release |
43 | // mode. |
44 | pdlPipeline.enableVerifier(false); |
45 | #endif |
46 | pdlPipeline.addPass(createPDLToPDLInterpPass(configMap)); |
47 | if (failed(pdlPipeline.run(op: pdlModule))) |
48 | return failure(); |
49 | |
50 | // Simplify again after running the lowering pipeline. |
51 | pdlModule.getBody()->walk(simplifyFn); |
52 | return success(); |
53 | } |
54 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // FrozenRewritePatternSet |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | FrozenRewritePatternSet::FrozenRewritePatternSet() |
61 | : impl(std::make_shared<Impl>()) {} |
62 | |
63 | FrozenRewritePatternSet::FrozenRewritePatternSet( |
64 | RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels, |
65 | ArrayRef<std::string> enabledPatternLabels) |
66 | : impl(std::make_shared<Impl>()) { |
67 | DenseSet<StringRef> disabledPatterns, enabledPatterns; |
68 | disabledPatterns.insert(I: disabledPatternLabels.begin(), |
69 | E: disabledPatternLabels.end()); |
70 | enabledPatterns.insert(I: enabledPatternLabels.begin(), |
71 | E: enabledPatternLabels.end()); |
72 | |
73 | // Functor used to walk all of the operations registered in the context. This |
74 | // is useful for patterns that get applied to multiple operations, such as |
75 | // interface and trait based patterns. |
76 | std::vector<RegisteredOperationName> opInfos; |
77 | auto addToOpsWhen = |
78 | [&](std::unique_ptr<RewritePattern> &pattern, |
79 | function_ref<bool(RegisteredOperationName)> callbackFn) { |
80 | if (opInfos.empty()) |
81 | opInfos = pattern->getContext()->getRegisteredOperations(); |
82 | for (RegisteredOperationName info : opInfos) |
83 | if (callbackFn(info)) |
84 | impl->nativeOpSpecificPatternMap[info].push_back(x: pattern.get()); |
85 | impl->nativeOpSpecificPatternList.push_back(x: std::move(pattern)); |
86 | }; |
87 | |
88 | for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) { |
89 | // Don't add patterns that haven't been enabled by the user. |
90 | if (!enabledPatterns.empty()) { |
91 | auto isEnabledFn = [&](StringRef label) { |
92 | return enabledPatterns.count(V: label); |
93 | }; |
94 | if (!isEnabledFn(pat->getDebugName()) && |
95 | llvm::none_of(Range: pat->getDebugLabels(), P: isEnabledFn)) |
96 | continue; |
97 | } |
98 | // Don't add patterns that have been disabled by the user. |
99 | if (!disabledPatterns.empty()) { |
100 | auto isDisabledFn = [&](StringRef label) { |
101 | return disabledPatterns.count(V: label); |
102 | }; |
103 | if (isDisabledFn(pat->getDebugName()) || |
104 | llvm::any_of(Range: pat->getDebugLabels(), P: isDisabledFn)) |
105 | continue; |
106 | } |
107 | |
108 | if (std::optional<OperationName> rootName = pat->getRootKind()) { |
109 | impl->nativeOpSpecificPatternMap[*rootName].push_back(x: pat.get()); |
110 | impl->nativeOpSpecificPatternList.push_back(x: std::move(pat)); |
111 | continue; |
112 | } |
113 | if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) { |
114 | addToOpsWhen(pat, [&](RegisteredOperationName info) { |
115 | return info.hasInterface(interfaceID: *interfaceID); |
116 | }); |
117 | continue; |
118 | } |
119 | if (std::optional<TypeID> traitID = pat->getRootTraitID()) { |
120 | addToOpsWhen(pat, [&](RegisteredOperationName info) { |
121 | return info.hasTrait(traitID: *traitID); |
122 | }); |
123 | continue; |
124 | } |
125 | impl->nativeAnyOpPatterns.push_back(x: std::move(pat)); |
126 | } |
127 | |
128 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
129 | // Generate the bytecode for the PDL patterns if any were provided. |
130 | PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); |
131 | ModuleOp pdlModule = pdlPatterns.getModule(); |
132 | if (!pdlModule) |
133 | return; |
134 | DenseMap<Operation *, PDLPatternConfigSet *> configMap = |
135 | pdlPatterns.takeConfigMap(); |
136 | if (failed(convertPDLToPDLInterp(pdlModule, configMap))) |
137 | llvm::report_fatal_error( |
138 | reason: "failed to lower PDL pattern module to the PDL Interpreter" ); |
139 | |
140 | // Generate the pdl bytecode. |
141 | impl->pdlByteCode = std::make_unique<detail::PDLByteCode>( |
142 | pdlModule, pdlPatterns.takeConfigs(), configMap, |
143 | pdlPatterns.takeConstraintFunctions(), |
144 | pdlPatterns.takeRewriteFunctions()); |
145 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
146 | } |
147 | |
// Defaulted destructor, defined here in the implementation file rather than
// inline in the class declaration.
148 | FrozenRewritePatternSet::~FrozenRewritePatternSet() = default; |
149 | |