1//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
10#include "mlir/Dialect/PDL/IR/PDLOps.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/OpImplementation.h"
13#include "mlir/Rewrite/FrozenRewritePatternSet.h"
14#include "mlir/Rewrite/PatternApplicator.h"
15#include "llvm/ADT/ScopeExit.h"
16
17using namespace mlir;
18
// Out-of-line definition of the explicit TypeID for
// transform::PDLMatchHooks, which this file retrieves as dialect extra data via
// `getExtraData<transform::PDLMatchHooks>()`.
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
20
21#define GET_OP_CLASSES
22#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
23
24//===----------------------------------------------------------------------===//
25// PatternApplicatorExtension
26//===----------------------------------------------------------------------===//
27
28namespace {
29/// A TransformState extension that keeps track of compiled PDL pattern sets.
30/// This is intended to be used along the WithPDLPatterns op. The extension
31/// can be constructed given an operation that has a SymbolTable trait and
32/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
33/// by one when requested; this behavior is subject to change.
34class PatternApplicatorExtension : public transform::TransformState::Extension {
35public:
36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
37
38 /// Creates the extension for patterns contained in `patternContainer`.
39 explicit PatternApplicatorExtension(transform::TransformState &state,
40 Operation *patternContainer)
41 : Extension(state), patterns(patternContainer) {}
42
43 /// Appends to `results` the operations contained in `root` that matched the
44 /// PDL pattern with the given name. Note that `root` may or may not be the
45 /// operation that contains PDL patterns. Reports an error if the pattern
46 /// cannot be found. Note that when no operations are matched, this still
47 /// succeeds as long as the pattern exists.
48 LogicalResult findAllMatches(StringRef patternName, Operation *root,
49 SmallVectorImpl<Operation *> &results);
50
51private:
52 /// Map from the pattern name to a singleton set of rewrite patterns that only
53 /// contains the pattern with this name. Populated when the pattern is first
54 /// requested.
55 // TODO: reconsider the efficiency of this storage when more usage data is
56 // available. Storing individual patterns in a set and triggering compilation
57 // for each of them has overhead. So does compiling a large set of patterns
58 // only to apply a handful of them.
59 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
60
61 /// A symbol table operation containing the relevant PDL patterns.
62 SymbolTable patterns;
63};
64
65LogicalResult PatternApplicatorExtension::findAllMatches(
66 StringRef patternName, Operation *root,
67 SmallVectorImpl<Operation *> &results) {
68 auto it = compiledPatterns.find(patternName);
69 if (it == compiledPatterns.end()) {
70 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
71 if (!patternOp)
72 return failure();
73
74 // Copy the pattern operation into a new module that is compiled and
75 // consumed by the PDL interpreter.
76 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
77 auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
78 builder.clone(*patternOp);
79 PDLPatternModule patternModule(std::move(pdlModuleOp));
80
81 // Merge in the hooks owned by the dialect. Make a copy as they may be
82 // also used by the following operations.
83 auto *dialect =
84 root->getContext()->getLoadedDialect<transform::TransformDialect>();
85 for (const auto &[name, constraintFn] :
86 dialect->getExtraData<transform::PDLMatchHooks>()
87 .getPDLConstraintHooks()) {
88 patternModule.registerConstraintFunction(name, constraintFn);
89 }
90
91 // Register a noop rewriter because PDL requires patterns to end with some
92 // rewrite call.
93 patternModule.registerRewriteFunction(
94 "transform.dialect", [](PatternRewriter &, Operation *) {});
95
96 it = compiledPatterns
97 .try_emplace(patternOp.getName(), std::move(patternModule))
98 .first;
99 }
100
101 PatternApplicator applicator(it->second);
102 // We want to discourage direct use of PatternRewriter in APIs but In this
103 // very specific case, an IRRewriter is not enough.
104 struct TrivialPatternRewriter : public PatternRewriter {
105 public:
106 explicit TrivialPatternRewriter(MLIRContext *context)
107 : PatternRewriter(context) {}
108 };
109 TrivialPatternRewriter rewriter(root->getContext());
110 applicator.applyDefaultCostModel();
111 root->walk([&](Operation *op) {
112 if (succeeded(result: applicator.matchAndRewrite(op, rewriter)))
113 results.push_back(Elt: op);
114 });
115
116 return success();
117}
118} // namespace
119
120//===----------------------------------------------------------------------===//
121// PDLMatchHooks
122//===----------------------------------------------------------------------===//
123
124void transform::PDLMatchHooks::mergeInPDLMatchHooks(
125 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
126 // Steal the constraint functions from the given map.
127 for (auto &it : constraintFns)
128 pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
129}
130
131const llvm::StringMap<PDLConstraintFunction> &
132transform::PDLMatchHooks::getPDLConstraintHooks() const {
133 return pdlMatchHooks.getConstraintFunctions();
134}
135
136//===----------------------------------------------------------------------===//
137// PDLMatchOp
138//===----------------------------------------------------------------------===//
139
140DiagnosedSilenceableFailure
141transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
142 transform::TransformResults &results,
143 transform::TransformState &state) {
144 auto *extension = state.getExtension<PatternApplicatorExtension>();
145 assert(extension &&
146 "expected PatternApplicatorExtension to be attached by the parent op");
147 SmallVector<Operation *> targets;
148 for (Operation *root : state.getPayloadOps(getRoot())) {
149 if (failed(extension->findAllMatches(
150 getPatternName().getLeafReference().getValue(), root, targets))) {
151 emitDefiniteFailure()
152 << "could not find pattern '" << getPatternName() << "'";
153 }
154 }
155 results.set(llvm::cast<OpResult>(getResult()), targets);
156 return DiagnosedSilenceableFailure::success();
157}
158
159void transform::PDLMatchOp::getEffects(
160 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
161 onlyReadsHandle(getRoot(), effects);
162 producesHandle(getMatched(), effects);
163 onlyReadsPayload(effects);
164}
165
166//===----------------------------------------------------------------------===//
167// WithPDLPatternsOp
168//===----------------------------------------------------------------------===//
169
170DiagnosedSilenceableFailure
171transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
172 transform::TransformResults &results,
173 transform::TransformState &state) {
174 TransformOpInterface transformOp = nullptr;
175 for (Operation &nested : getBody().front()) {
176 if (!isa<pdl::PatternOp>(nested)) {
177 transformOp = cast<TransformOpInterface>(nested);
178 break;
179 }
180 }
181
182 state.addExtension<PatternApplicatorExtension>(getOperation());
183 auto guard = llvm::make_scope_exit(
184 [&]() { state.removeExtension<PatternApplicatorExtension>(); });
185
186 auto scope = state.make_region_scope(getBody());
187 if (failed(mapBlockArguments(state)))
188 return DiagnosedSilenceableFailure::definiteFailure();
189 return state.applyTransform(transformOp);
190}
191
/// Reports the memory effects of this op. They are delegated to
/// getPotentialTopLevelEffects, since the op may act as a top-level transform.
void transform::WithPDLPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  getPotentialTopLevelEffects(effects);
}
196
197LogicalResult transform::WithPDLPatternsOp::verify() {
198 Block *body = getBodyBlock();
199 Operation *topLevelOp = nullptr;
200 for (Operation &op : body->getOperations()) {
201 if (isa<pdl::PatternOp>(op))
202 continue;
203
204 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
205 if (topLevelOp) {
206 InFlightDiagnostic diag =
207 emitOpError() << "expects only one non-pattern op in its body";
208 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
209 diag.attachNote(op.getLoc()) << "second non-pattern op";
210 return diag;
211 }
212 topLevelOp = &op;
213 continue;
214 }
215
216 InFlightDiagnostic diag =
217 emitOpError()
218 << "expects only pattern and top-level transform ops in its body";
219 diag.attachNote(op.getLoc()) << "offending op";
220 return diag;
221 }
222
223 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
224 InFlightDiagnostic diag = emitOpError() << "cannot be nested";
225 diag.attachNote(parent.getLoc()) << "parent operation";
226 return diag;
227 }
228
229 if (!topLevelOp) {
230 InFlightDiagnostic diag = emitOpError()
231 << "expects at least one non-pattern op";
232 return diag;
233 }
234
235 return success();
236}
237

// source code of mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp