1//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
10#include "mlir/Dialect/PDL/IR/PDLOps.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/Rewrite/FrozenRewritePatternSet.h"
13#include "mlir/Rewrite/PatternApplicator.h"
14#include "llvm/ADT/ScopeExit.h"
15
16using namespace mlir;
17
// Out-of-line definition of the explicit TypeID for PDLMatchHooks.
// NOTE(review): the matching MLIR_DECLARE_EXPLICIT_TYPE_ID presumably lives in
// PDLExtensionOps.h (not visible here) -- confirm.
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
19
20#define GET_OP_CLASSES
21#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// PatternApplicatorExtension
25//===----------------------------------------------------------------------===//
26
27namespace {
28/// A TransformState extension that keeps track of compiled PDL pattern sets.
29/// This is intended to be used along the WithPDLPatterns op. The extension
30/// can be constructed given an operation that has a SymbolTable trait and
31/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
32/// by one when requested; this behavior is subject to change.
33class PatternApplicatorExtension : public transform::TransformState::Extension {
34public:
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
36
37 /// Creates the extension for patterns contained in `patternContainer`.
38 explicit PatternApplicatorExtension(transform::TransformState &state,
39 Operation *patternContainer)
40 : Extension(state), patterns(patternContainer) {}
41
42 /// Appends to `results` the operations contained in `root` that matched the
43 /// PDL pattern with the given name. Note that `root` may or may not be the
44 /// operation that contains PDL patterns. Reports an error if the pattern
45 /// cannot be found. Note that when no operations are matched, this still
46 /// succeeds as long as the pattern exists.
47 LogicalResult findAllMatches(StringRef patternName, Operation *root,
48 SmallVectorImpl<Operation *> &results);
49
50private:
51 /// Map from the pattern name to a singleton set of rewrite patterns that only
52 /// contains the pattern with this name. Populated when the pattern is first
53 /// requested.
54 // TODO: reconsider the efficiency of this storage when more usage data is
55 // available. Storing individual patterns in a set and triggering compilation
56 // for each of them has overhead. So does compiling a large set of patterns
57 // only to apply a handful of them.
58 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
59
60 /// A symbol table operation containing the relevant PDL patterns.
61 SymbolTable patterns;
62};
63
64LogicalResult PatternApplicatorExtension::findAllMatches(
65 StringRef patternName, Operation *root,
66 SmallVectorImpl<Operation *> &results) {
67 auto it = compiledPatterns.find(Key: patternName);
68 if (it == compiledPatterns.end()) {
69 auto patternOp = patterns.lookup<pdl::PatternOp>(name: patternName);
70 if (!patternOp)
71 return failure();
72
73 // Copy the pattern operation into a new module that is compiled and
74 // consumed by the PDL interpreter.
75 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(loc: patternOp.getLoc());
76 auto builder = OpBuilder::atBlockEnd(block: pdlModuleOp->getBody());
77 builder.clone(op&: *patternOp);
78 PDLPatternModule patternModule(std::move(pdlModuleOp));
79
80 // Merge in the hooks owned by the dialect. Make a copy as they may be
81 // also used by the following operations.
82 auto *dialect =
83 root->getContext()->getLoadedDialect<transform::TransformDialect>();
84 for (const auto &[name, constraintFn] :
85 dialect->getExtraData<transform::PDLMatchHooks>()
86 .getPDLConstraintHooks()) {
87 patternModule.registerConstraintFunction(name, constraintFn);
88 }
89
90 // Register a noop rewriter because PDL requires patterns to end with some
91 // rewrite call.
92 patternModule.registerRewriteFunction(
93 name: "transform.dialect", rewriteFn: [](PatternRewriter &, Operation *) {});
94
95 it = compiledPatterns
96 .try_emplace(Key: patternOp.getName(), Args: std::move(patternModule))
97 .first;
98 }
99
100 PatternApplicator applicator(it->second);
101 // We want to discourage direct use of PatternRewriter in APIs but In this
102 // very specific case, an IRRewriter is not enough.
103 struct TrivialPatternRewriter : public PatternRewriter {
104 public:
105 explicit TrivialPatternRewriter(MLIRContext *context)
106 : PatternRewriter(context) {}
107 };
108 TrivialPatternRewriter rewriter(root->getContext());
109 applicator.applyDefaultCostModel();
110 root->walk(callback: [&](Operation *op) {
111 if (succeeded(Result: applicator.matchAndRewrite(op, rewriter)))
112 results.push_back(Elt: op);
113 });
114
115 return success();
116}
117} // namespace
118
119//===----------------------------------------------------------------------===//
120// PDLMatchHooks
121//===----------------------------------------------------------------------===//
122
123void transform::PDLMatchHooks::mergeInPDLMatchHooks(
124 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
125 // Steal the constraint functions from the given map.
126 for (auto &it : constraintFns)
127 pdlMatchHooks.registerConstraintFunction(name: it.getKey(), constraintFn: std::move(it.second));
128}
129
/// Returns the constraint functions registered with the underlying
/// `pdlMatchHooks` module, e.g. through `mergeInPDLMatchHooks`. The returned
/// reference points into storage owned by this object.
const llvm::StringMap<PDLConstraintFunction> &
transform::PDLMatchHooks::getPDLConstraintHooks() const {
  return pdlMatchHooks.getConstraintFunctions();
}
134
135//===----------------------------------------------------------------------===//
136// PDLMatchOp
137//===----------------------------------------------------------------------===//
138
139DiagnosedSilenceableFailure
140transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
141 transform::TransformResults &results,
142 transform::TransformState &state) {
143 auto *extension = state.getExtension<PatternApplicatorExtension>();
144 assert(extension &&
145 "expected PatternApplicatorExtension to be attached by the parent op");
146 SmallVector<Operation *> targets;
147 for (Operation *root : state.getPayloadOps(value: getRoot())) {
148 if (failed(Result: extension->findAllMatches(
149 patternName: getPatternName().getLeafReference().getValue(), root, results&: targets))) {
150 emitDefiniteFailure()
151 << "could not find pattern '" << getPatternName() << "'";
152 }
153 }
154 results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: targets);
155 return DiagnosedSilenceableFailure::success();
156}
157
/// Declares the memory effects of PDLMatchOp:
/// - the `root` handle is only read;
/// - the result handle(s) are produced;
/// - the payload IR is only read, never modified.
void transform::PDLMatchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getRootMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  onlyReadsPayload(effects);
}
164
165//===----------------------------------------------------------------------===//
166// WithPDLPatternsOp
167//===----------------------------------------------------------------------===//
168
169DiagnosedSilenceableFailure
170transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
171 transform::TransformResults &results,
172 transform::TransformState &state) {
173 TransformOpInterface transformOp = nullptr;
174 for (Operation &nested : getBody().front()) {
175 if (!isa<pdl::PatternOp>(Val: nested)) {
176 transformOp = cast<TransformOpInterface>(Val&: nested);
177 break;
178 }
179 }
180
181 state.addExtension<PatternApplicatorExtension>(args: getOperation());
182 auto guard = llvm::make_scope_exit(
183 F: [&]() { state.removeExtension<PatternApplicatorExtension>(); });
184
185 auto scope = state.make_region_scope(region&: getBody());
186 if (failed(Result: mapBlockArguments(state)))
187 return DiagnosedSilenceableFailure::definiteFailure();
188 return state.applyTransform(transform: transformOp);
189}
190
/// Reports the memory effects of this op by delegating entirely to
/// `getPotentialTopLevelEffects`.
void transform::WithPDLPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  getPotentialTopLevelEffects(effects);
}
195
196LogicalResult transform::WithPDLPatternsOp::verify() {
197 Block *body = getBodyBlock();
198 Operation *topLevelOp = nullptr;
199 for (Operation &op : body->getOperations()) {
200 if (isa<pdl::PatternOp>(Val: op))
201 continue;
202
203 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
204 if (topLevelOp) {
205 InFlightDiagnostic diag =
206 emitOpError() << "expects only one non-pattern op in its body";
207 diag.attachNote(noteLoc: topLevelOp->getLoc()) << "first non-pattern op";
208 diag.attachNote(noteLoc: op.getLoc()) << "second non-pattern op";
209 return diag;
210 }
211 topLevelOp = &op;
212 continue;
213 }
214
215 InFlightDiagnostic diag =
216 emitOpError()
217 << "expects only pattern and top-level transform ops in its body";
218 diag.attachNote(noteLoc: op.getLoc()) << "offending op";
219 return diag;
220 }
221
222 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
223 InFlightDiagnostic diag = emitOpError() << "cannot be nested";
224 diag.attachNote(noteLoc: parent.getLoc()) << "parent operation";
225 return diag;
226 }
227
228 if (!topLevelOp) {
229 InFlightDiagnostic diag = emitOpError()
230 << "expects at least one non-pattern op";
231 return diag;
232 }
233
234 return success();
235}
236

source code of mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp