1 | //===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" |
10 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/OpImplementation.h" |
13 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
14 | #include "mlir/Rewrite/PatternApplicator.h" |
15 | #include "llvm/ADT/ScopeExit.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks) |
20 | |
21 | #define GET_OP_CLASSES |
22 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // PatternApplicatorExtension |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | namespace { |
29 | /// A TransformState extension that keeps track of compiled PDL pattern sets. |
30 | /// This is intended to be used along the WithPDLPatterns op. The extension |
31 | /// can be constructed given an operation that has a SymbolTable trait and |
32 | /// contains pdl::PatternOp instances. The patterns are compiled lazily and one |
33 | /// by one when requested; this behavior is subject to change. |
34 | class PatternApplicatorExtension : public transform::TransformState::Extension { |
35 | public: |
36 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) |
37 | |
38 | /// Creates the extension for patterns contained in `patternContainer`. |
39 | explicit PatternApplicatorExtension(transform::TransformState &state, |
40 | Operation *patternContainer) |
41 | : Extension(state), patterns(patternContainer) {} |
42 | |
43 | /// Appends to `results` the operations contained in `root` that matched the |
44 | /// PDL pattern with the given name. Note that `root` may or may not be the |
45 | /// operation that contains PDL patterns. Reports an error if the pattern |
46 | /// cannot be found. Note that when no operations are matched, this still |
47 | /// succeeds as long as the pattern exists. |
48 | LogicalResult findAllMatches(StringRef patternName, Operation *root, |
49 | SmallVectorImpl<Operation *> &results); |
50 | |
51 | private: |
52 | /// Map from the pattern name to a singleton set of rewrite patterns that only |
53 | /// contains the pattern with this name. Populated when the pattern is first |
54 | /// requested. |
55 | // TODO: reconsider the efficiency of this storage when more usage data is |
56 | // available. Storing individual patterns in a set and triggering compilation |
57 | // for each of them has overhead. So does compiling a large set of patterns |
58 | // only to apply a handful of them. |
59 | llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; |
60 | |
61 | /// A symbol table operation containing the relevant PDL patterns. |
62 | SymbolTable patterns; |
63 | }; |
64 | |
65 | LogicalResult PatternApplicatorExtension::findAllMatches( |
66 | StringRef patternName, Operation *root, |
67 | SmallVectorImpl<Operation *> &results) { |
68 | auto it = compiledPatterns.find(patternName); |
69 | if (it == compiledPatterns.end()) { |
70 | auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); |
71 | if (!patternOp) |
72 | return failure(); |
73 | |
74 | // Copy the pattern operation into a new module that is compiled and |
75 | // consumed by the PDL interpreter. |
76 | OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); |
77 | auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); |
78 | builder.clone(*patternOp); |
79 | PDLPatternModule patternModule(std::move(pdlModuleOp)); |
80 | |
81 | // Merge in the hooks owned by the dialect. Make a copy as they may be |
82 | // also used by the following operations. |
83 | auto *dialect = |
84 | root->getContext()->getLoadedDialect<transform::TransformDialect>(); |
85 | for (const auto &[name, constraintFn] : |
86 | dialect->getExtraData<transform::PDLMatchHooks>() |
87 | .getPDLConstraintHooks()) { |
88 | patternModule.registerConstraintFunction(name, constraintFn); |
89 | } |
90 | |
91 | // Register a noop rewriter because PDL requires patterns to end with some |
92 | // rewrite call. |
93 | patternModule.registerRewriteFunction( |
94 | "transform.dialect" , [](PatternRewriter &, Operation *) {}); |
95 | |
96 | it = compiledPatterns |
97 | .try_emplace(patternOp.getName(), std::move(patternModule)) |
98 | .first; |
99 | } |
100 | |
101 | PatternApplicator applicator(it->second); |
102 | // We want to discourage direct use of PatternRewriter in APIs but In this |
103 | // very specific case, an IRRewriter is not enough. |
104 | struct TrivialPatternRewriter : public PatternRewriter { |
105 | public: |
106 | explicit TrivialPatternRewriter(MLIRContext *context) |
107 | : PatternRewriter(context) {} |
108 | }; |
109 | TrivialPatternRewriter rewriter(root->getContext()); |
110 | applicator.applyDefaultCostModel(); |
111 | root->walk([&](Operation *op) { |
112 | if (succeeded(result: applicator.matchAndRewrite(op, rewriter))) |
113 | results.push_back(Elt: op); |
114 | }); |
115 | |
116 | return success(); |
117 | } |
118 | } // namespace |
119 | |
120 | //===----------------------------------------------------------------------===// |
121 | // PDLMatchHooks |
122 | //===----------------------------------------------------------------------===// |
123 | |
124 | void transform::PDLMatchHooks::mergeInPDLMatchHooks( |
125 | llvm::StringMap<PDLConstraintFunction> &&constraintFns) { |
126 | // Steal the constraint functions from the given map. |
127 | for (auto &it : constraintFns) |
128 | pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); |
129 | } |
130 | |
131 | const llvm::StringMap<PDLConstraintFunction> & |
132 | transform::PDLMatchHooks::getPDLConstraintHooks() const { |
133 | return pdlMatchHooks.getConstraintFunctions(); |
134 | } |
135 | |
136 | //===----------------------------------------------------------------------===// |
137 | // PDLMatchOp |
138 | //===----------------------------------------------------------------------===// |
139 | |
140 | DiagnosedSilenceableFailure |
141 | transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, |
142 | transform::TransformResults &results, |
143 | transform::TransformState &state) { |
144 | auto *extension = state.getExtension<PatternApplicatorExtension>(); |
145 | assert(extension && |
146 | "expected PatternApplicatorExtension to be attached by the parent op" ); |
147 | SmallVector<Operation *> targets; |
148 | for (Operation *root : state.getPayloadOps(getRoot())) { |
149 | if (failed(extension->findAllMatches( |
150 | getPatternName().getLeafReference().getValue(), root, targets))) { |
151 | emitDefiniteFailure() |
152 | << "could not find pattern '" << getPatternName() << "'" ; |
153 | } |
154 | } |
155 | results.set(llvm::cast<OpResult>(getResult()), targets); |
156 | return DiagnosedSilenceableFailure::success(); |
157 | } |
158 | |
159 | void transform::PDLMatchOp::getEffects( |
160 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
161 | onlyReadsHandle(getRoot(), effects); |
162 | producesHandle(getMatched(), effects); |
163 | onlyReadsPayload(effects); |
164 | } |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // WithPDLPatternsOp |
168 | //===----------------------------------------------------------------------===// |
169 | |
170 | DiagnosedSilenceableFailure |
171 | transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, |
172 | transform::TransformResults &results, |
173 | transform::TransformState &state) { |
174 | TransformOpInterface transformOp = nullptr; |
175 | for (Operation &nested : getBody().front()) { |
176 | if (!isa<pdl::PatternOp>(nested)) { |
177 | transformOp = cast<TransformOpInterface>(nested); |
178 | break; |
179 | } |
180 | } |
181 | |
182 | state.addExtension<PatternApplicatorExtension>(getOperation()); |
183 | auto guard = llvm::make_scope_exit( |
184 | [&]() { state.removeExtension<PatternApplicatorExtension>(); }); |
185 | |
186 | auto scope = state.make_region_scope(getBody()); |
187 | if (failed(mapBlockArguments(state))) |
188 | return DiagnosedSilenceableFailure::definiteFailure(); |
189 | return state.applyTransform(transformOp); |
190 | } |
191 | |
192 | void transform::WithPDLPatternsOp::getEffects( |
193 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
194 | getPotentialTopLevelEffects(effects); |
195 | } |
196 | |
197 | LogicalResult transform::WithPDLPatternsOp::verify() { |
198 | Block *body = getBodyBlock(); |
199 | Operation *topLevelOp = nullptr; |
200 | for (Operation &op : body->getOperations()) { |
201 | if (isa<pdl::PatternOp>(op)) |
202 | continue; |
203 | |
204 | if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { |
205 | if (topLevelOp) { |
206 | InFlightDiagnostic diag = |
207 | emitOpError() << "expects only one non-pattern op in its body" ; |
208 | diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op" ; |
209 | diag.attachNote(op.getLoc()) << "second non-pattern op" ; |
210 | return diag; |
211 | } |
212 | topLevelOp = &op; |
213 | continue; |
214 | } |
215 | |
216 | InFlightDiagnostic diag = |
217 | emitOpError() |
218 | << "expects only pattern and top-level transform ops in its body" ; |
219 | diag.attachNote(op.getLoc()) << "offending op" ; |
220 | return diag; |
221 | } |
222 | |
223 | if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { |
224 | InFlightDiagnostic diag = emitOpError() << "cannot be nested" ; |
225 | diag.attachNote(parent.getLoc()) << "parent operation" ; |
226 | return diag; |
227 | } |
228 | |
229 | if (!topLevelOp) { |
230 | InFlightDiagnostic diag = emitOpError() |
231 | << "expects at least one non-pattern op" ; |
232 | return diag; |
233 | } |
234 | |
235 | return success(); |
236 | } |
237 | |