1 | //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_PATTERNMATCH_H |
10 | #define MLIR_IR_PATTERNMATCH_H |
11 | |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "llvm/ADT/FunctionExtras.h" |
15 | #include "llvm/Support/TypeName.h" |
16 | #include <optional> |
17 | |
18 | using llvm::SmallPtrSetImpl; |
19 | namespace mlir { |
20 | |
21 | class PatternRewriter; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // PatternBenefit class |
25 | //===----------------------------------------------------------------------===// |
26 | |
/// A unitless measure of how profitable it is to apply a pattern, ranging from
/// 0 (very little benefit) to roughly 65K. The most common unit is the "number
/// of operations matched" by the pattern.
///
/// The largest value storable in the underlying `unsigned short` (65535) is
/// reserved as a sentinel meaning "this pattern can never match"; a
/// default-constructed PatternBenefit holds that sentinel.
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  PatternBenefit() = default;
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Returns the sentinel benefit used for patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit{}; }

  /// Returns true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const {
    return representation == ImpossibleToMatchSentinel;
  }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Benefits are totally ordered by their stored value. Because the sentinel
  /// is the maximum `unsigned short`, it compares greater than every other
  /// benefit.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const {
    return representation != rhs.representation;
  }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const {
    return rhs.representation < representation;
  }
  bool operator<=(const PatternBenefit &rhs) const {
    return representation <= rhs.representation;
  }
  bool operator>=(const PatternBenefit &rhs) const {
    return representation >= rhs.representation;
  }

private:
  /// Raw benefit value, or ImpossibleToMatchSentinel.
  unsigned short representation{ImpossibleToMatchSentinel};
};
64 | |
65 | //===----------------------------------------------------------------------===// |
66 | // Pattern |
67 | //===----------------------------------------------------------------------===// |
68 | |
69 | /// This class contains all of the data related to a pattern, but does not |
70 | /// contain any methods or logic for the actual matching. This class is solely |
71 | /// used to interface with the metadata of a pattern, such as the benefit or |
72 | /// root operation. |
73 | class Pattern { |
74 | /// This enum represents the kind of value used to select the root operations |
75 | /// that match this pattern. |
76 | enum class RootKind { |
77 | /// The pattern root matches "any" operation. |
78 | Any, |
79 | /// The pattern root is matched using a concrete operation name. |
80 | OperationName, |
81 | /// The pattern root is matched using an interface ID. |
82 | InterfaceID, |
83 | /// The patter root is matched using a trait ID. |
84 | TraitID |
85 | }; |
86 | |
87 | public: |
88 | /// Return a list of operations that may be generated when rewriting an |
89 | /// operation instance with this pattern. |
90 | ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; } |
91 | |
92 | /// Return the root node that this pattern matches. Patterns that can match |
93 | /// multiple root types return std::nullopt. |
94 | std::optional<OperationName> getRootKind() const { |
95 | if (rootKind == RootKind::OperationName) |
96 | return OperationName::getFromOpaquePointer(pointer: rootValue); |
97 | return std::nullopt; |
98 | } |
99 | |
100 | /// Return the interface ID used to match the root operation of this pattern. |
101 | /// If the pattern does not use an interface ID for deciding the root match, |
102 | /// this returns std::nullopt. |
103 | std::optional<TypeID> getRootInterfaceID() const { |
104 | if (rootKind == RootKind::InterfaceID) |
105 | return TypeID::getFromOpaquePointer(pointer: rootValue); |
106 | return std::nullopt; |
107 | } |
108 | |
109 | /// Return the trait ID used to match the root operation of this pattern. |
110 | /// If the pattern does not use a trait ID for deciding the root match, this |
111 | /// returns std::nullopt. |
112 | std::optional<TypeID> getRootTraitID() const { |
113 | if (rootKind == RootKind::TraitID) |
114 | return TypeID::getFromOpaquePointer(pointer: rootValue); |
115 | return std::nullopt; |
116 | } |
117 | |
118 | /// Return the benefit (the inverse of "cost") of matching this pattern. The |
119 | /// benefit of a Pattern is always static - rewrites that may have dynamic |
120 | /// benefit can be instantiated multiple times (different Pattern instances) |
121 | /// for each benefit that they may return, and be guarded by different match |
122 | /// condition predicates. |
123 | PatternBenefit getBenefit() const { return benefit; } |
124 | |
125 | /// Returns true if this pattern is known to result in recursive application, |
126 | /// i.e. this pattern may generate IR that also matches this pattern, but is |
127 | /// known to bound the recursion. This signals to a rewrite driver that it is |
128 | /// safe to apply this pattern recursively to generated IR. |
129 | bool hasBoundedRewriteRecursion() const { |
130 | return contextAndHasBoundedRecursion.getInt(); |
131 | } |
132 | |
133 | /// Return the MLIRContext used to create this pattern. |
134 | MLIRContext *getContext() const { |
135 | return contextAndHasBoundedRecursion.getPointer(); |
136 | } |
137 | |
138 | /// Return a readable name for this pattern. This name should only be used for |
139 | /// debugging purposes, and may be empty. |
140 | StringRef getDebugName() const { return debugName; } |
141 | |
142 | /// Set the human readable debug name used for this pattern. This name will |
143 | /// only be used for debugging purposes. |
144 | void setDebugName(StringRef name) { debugName = name; } |
145 | |
146 | /// Return the set of debug labels attached to this pattern. |
147 | ArrayRef<StringRef> getDebugLabels() const { return debugLabels; } |
148 | |
149 | /// Add the provided debug labels to this pattern. |
150 | void addDebugLabels(ArrayRef<StringRef> labels) { |
151 | debugLabels.append(in_start: labels.begin(), in_end: labels.end()); |
152 | } |
153 | void addDebugLabels(StringRef label) { debugLabels.push_back(Elt: label); } |
154 | |
155 | protected: |
156 | /// This class acts as a special tag that makes the desire to match "any" |
157 | /// operation type explicit. This helps to avoid unnecessary usages of this |
158 | /// feature, and ensures that the user is making a conscious decision. |
159 | struct MatchAnyOpTypeTag {}; |
160 | /// This class acts as a special tag that makes the desire to match any |
161 | /// operation that implements a given interface explicit. This helps to avoid |
162 | /// unnecessary usages of this feature, and ensures that the user is making a |
163 | /// conscious decision. |
164 | struct MatchInterfaceOpTypeTag {}; |
165 | /// This class acts as a special tag that makes the desire to match any |
166 | /// operation that implements a given trait explicit. This helps to avoid |
167 | /// unnecessary usages of this feature, and ensures that the user is making a |
168 | /// conscious decision. |
169 | struct MatchTraitOpTypeTag {}; |
170 | |
171 | /// Construct a pattern with a certain benefit that matches the operation |
172 | /// with the given root name. |
173 | Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, |
174 | ArrayRef<StringRef> generatedNames = {}); |
175 | /// Construct a pattern that may match any operation type. `generatedNames` |
176 | /// contains the names of operations that may be generated during a successful |
177 | /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" |
178 | /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should |
179 | /// always be supplied here. |
180 | Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, |
181 | ArrayRef<StringRef> generatedNames = {}); |
182 | /// Construct a pattern that may match any operation that implements the |
183 | /// interface defined by the provided `interfaceID`. `generatedNames` contains |
184 | /// the names of operations that may be generated during a successful rewrite. |
185 | /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match |
186 | /// interface" behavior is what the user actually desired, |
187 | /// `MatchInterfaceOpTypeTag()` should always be supplied here. |
188 | Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, |
189 | PatternBenefit benefit, MLIRContext *context, |
190 | ArrayRef<StringRef> generatedNames = {}); |
191 | /// Construct a pattern that may match any operation that implements the |
192 | /// trait defined by the provided `traitID`. `generatedNames` contains the |
193 | /// names of operations that may be generated during a successful rewrite. |
194 | /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait" |
195 | /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should |
196 | /// always be supplied here. |
197 | Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, |
198 | MLIRContext *context, ArrayRef<StringRef> generatedNames = {}); |
199 | |
200 | /// Set the flag detailing if this pattern has bounded rewrite recursion or |
201 | /// not. |
202 | void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) { |
203 | contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg); |
204 | } |
205 | |
206 | private: |
207 | Pattern(const void *rootValue, RootKind rootKind, |
208 | ArrayRef<StringRef> generatedNames, PatternBenefit benefit, |
209 | MLIRContext *context); |
210 | |
211 | /// The value used to match the root operation of the pattern. |
212 | const void *rootValue; |
213 | RootKind rootKind; |
214 | |
215 | /// The expected benefit of matching this pattern. |
216 | const PatternBenefit benefit; |
217 | |
218 | /// The context this pattern was created from, and a boolean flag indicating |
219 | /// whether this pattern has bounded recursion or not. |
220 | llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion; |
221 | |
222 | /// A list of the potential operations that may be generated when rewriting |
223 | /// an op with this pattern. |
224 | SmallVector<OperationName, 2> generatedOps; |
225 | |
226 | /// A readable name for this pattern. May be empty. |
227 | StringRef debugName; |
228 | |
229 | /// The set of debug labels attached to this pattern. |
230 | SmallVector<StringRef, 0> debugLabels; |
231 | }; |
232 | |
233 | //===----------------------------------------------------------------------===// |
234 | // RewritePattern |
235 | //===----------------------------------------------------------------------===// |
236 | |
237 | /// RewritePattern is the common base class for all DAG to DAG replacements. |
238 | /// There are two possible usages of this class: |
239 | /// * Multi-step RewritePattern with "match" and "rewrite" |
240 | /// - By overloading the "match" and "rewrite" functions, the user can |
241 | /// separate the concerns of matching and rewriting. |
242 | /// * Single-step RewritePattern with "matchAndRewrite" |
243 | /// - By overloading the "matchAndRewrite" function, the user can perform |
244 | /// the rewrite in the same call as the match. |
245 | /// |
246 | class RewritePattern : public Pattern { |
247 | public: |
248 | virtual ~RewritePattern() = default; |
249 | |
250 | /// Rewrite the IR rooted at the specified operation with the result of |
251 | /// this pattern, generating any new operations with the specified |
252 | /// builder. If an unexpected error is encountered (an internal |
253 | /// compiler error), it is emitted through the normal MLIR diagnostic |
254 | /// hooks and the IR is left in a valid state. |
255 | virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; |
256 | |
257 | /// Attempt to match against code rooted at the specified operation, |
258 | /// which is the same operation code as getRootKind(). |
259 | virtual LogicalResult match(Operation *op) const; |
260 | |
261 | /// Attempt to match against code rooted at the specified operation, |
262 | /// which is the same operation code as getRootKind(). If successful, this |
263 | /// function will automatically perform the rewrite. |
264 | virtual LogicalResult matchAndRewrite(Operation *op, |
265 | PatternRewriter &rewriter) const { |
266 | if (succeeded(result: match(op))) { |
267 | rewrite(op, rewriter); |
268 | return success(); |
269 | } |
270 | return failure(); |
271 | } |
272 | |
273 | /// This method provides a convenient interface for creating and initializing |
274 | /// derived rewrite patterns of the given type `T`. |
275 | template <typename T, typename... Args> |
276 | static std::unique_ptr<T> create(Args &&...args) { |
277 | std::unique_ptr<T> pattern = |
278 | std::make_unique<T>(std::forward<Args>(args)...); |
279 | initializePattern<T>(*pattern); |
280 | |
281 | // Set a default debug name if one wasn't provided. |
282 | if (pattern->getDebugName().empty()) |
283 | pattern->setDebugName(llvm::getTypeName<T>()); |
284 | return pattern; |
285 | } |
286 | |
287 | protected: |
288 | /// Inherit the base constructors from `Pattern`. |
289 | using Pattern::Pattern; |
290 | |
291 | private: |
292 | /// Trait to check if T provides a `getOperationName` method. |
293 | template <typename T, typename... Args> |
294 | using has_initialize = decltype(std::declval<T>().initialize()); |
295 | template <typename T> |
296 | using detect_has_initialize = llvm::is_detected<has_initialize, T>; |
297 | |
298 | /// Initialize the derived pattern by calling its `initialize` method. |
299 | template <typename T> |
300 | static std::enable_if_t<detect_has_initialize<T>::value> |
301 | initializePattern(T &pattern) { |
302 | pattern.initialize(); |
303 | } |
304 | /// Empty derived pattern initializer for patterns that do not have an |
305 | /// initialize method. |
306 | template <typename T> |
307 | static std::enable_if_t<!detect_has_initialize<T>::value> |
308 | initializePattern(T &) {} |
309 | |
310 | /// An anchor for the virtual table. |
311 | virtual void anchor(); |
312 | }; |
313 | |
314 | namespace detail { |
315 | /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that |
316 | /// allows for matching and rewriting against an instance of a derived operation |
317 | /// class or Interface. |
318 | template <typename SourceOp> |
319 | struct OpOrInterfaceRewritePatternBase : public RewritePattern { |
320 | using RewritePattern::RewritePattern; |
321 | |
322 | /// Wrappers around the RewritePattern methods that pass the derived op type. |
323 | void rewrite(Operation *op, PatternRewriter &rewriter) const final { |
324 | rewrite(cast<SourceOp>(op), rewriter); |
325 | } |
326 | LogicalResult match(Operation *op) const final { |
327 | return match(cast<SourceOp>(op)); |
328 | } |
329 | LogicalResult matchAndRewrite(Operation *op, |
330 | PatternRewriter &rewriter) const final { |
331 | return matchAndRewrite(cast<SourceOp>(op), rewriter); |
332 | } |
333 | |
334 | /// Rewrite and Match methods that operate on the SourceOp type. These must be |
335 | /// overridden by the derived pattern class. |
336 | virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { |
337 | llvm_unreachable("must override rewrite or matchAndRewrite" ); |
338 | } |
339 | virtual LogicalResult match(SourceOp op) const { |
340 | llvm_unreachable("must override match or matchAndRewrite" ); |
341 | } |
342 | virtual LogicalResult matchAndRewrite(SourceOp op, |
343 | PatternRewriter &rewriter) const { |
344 | if (succeeded(match(op))) { |
345 | rewrite(op, rewriter); |
346 | return success(); |
347 | } |
348 | return failure(); |
349 | } |
350 | }; |
351 | } // namespace detail |
352 | |
353 | /// OpRewritePattern is a wrapper around RewritePattern that allows for |
354 | /// matching and rewriting against an instance of a derived operation class as |
355 | /// opposed to a raw Operation. |
356 | template <typename SourceOp> |
357 | struct OpRewritePattern |
358 | : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { |
359 | /// Patterns must specify the root operation name they match against, and can |
360 | /// also specify the benefit of the pattern matching and a list of generated |
361 | /// ops. |
362 | OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1, |
363 | ArrayRef<StringRef> generatedNames = {}) |
364 | : detail::OpOrInterfaceRewritePatternBase<SourceOp>( |
365 | SourceOp::getOperationName(), benefit, context, generatedNames) {} |
366 | }; |
367 | |
368 | /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for |
369 | /// matching and rewriting against an instance of an operation interface instead |
370 | /// of a raw Operation. |
371 | template <typename SourceOp> |
372 | struct OpInterfaceRewritePattern |
373 | : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { |
374 | OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) |
375 | : detail::OpOrInterfaceRewritePatternBase<SourceOp>( |
376 | Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), |
377 | benefit, context) {} |
378 | }; |
379 | |
380 | /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for |
381 | /// matching and rewriting against instances of an operation that possess a |
382 | /// given trait. |
383 | template <template <typename> class TraitType> |
384 | class OpTraitRewritePattern : public RewritePattern { |
385 | public: |
386 | OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) |
387 | : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(), |
388 | benefit, context) {} |
389 | }; |
390 | |
391 | //===----------------------------------------------------------------------===// |
392 | // RewriterBase |
393 | //===----------------------------------------------------------------------===// |
394 | |
395 | /// This class coordinates the application of a rewrite on a set of IR, |
396 | /// providing a way for clients to track mutations and create new operations. |
397 | /// This class serves as a common API for IR mutation between pattern rewrites |
398 | /// and non-pattern rewrites, and facilitates the development of shared |
399 | /// IR transformation utilities. |
400 | class RewriterBase : public OpBuilder { |
401 | public: |
402 | struct Listener : public OpBuilder::Listener { |
403 | Listener() |
404 | : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {} |
405 | |
406 | /// Notify the listener that the specified block is about to be erased. |
407 | /// At this point, the block has zero uses. |
408 | virtual void notifyBlockErased(Block *block) {} |
409 | |
410 | /// Notify the listener that the specified operation was modified in-place. |
411 | virtual void notifyOperationModified(Operation *op) {} |
412 | |
413 | /// Notify the listener that all uses of the specified operation's results |
414 | /// are about to be replaced with the results of another operation. This is |
415 | /// called before the uses of the old operation have been changed. |
416 | /// |
417 | /// By default, this function calls the "operation replaced with values" |
418 | /// notification. |
419 | virtual void notifyOperationReplaced(Operation *op, |
420 | Operation *replacement) { |
421 | notifyOperationReplaced(op, replacement: replacement->getResults()); |
422 | } |
423 | |
424 | /// Notify the listener that all uses of the specified operation's results |
425 | /// are about to be replaced with the a range of values, potentially |
426 | /// produced by other operations. This is called before the uses of the |
427 | /// operation have been changed. |
428 | virtual void notifyOperationReplaced(Operation *op, |
429 | ValueRange replacement) {} |
430 | |
431 | /// Notify the listener that the specified operation is about to be erased. |
432 | /// At this point, the operation has zero uses. |
433 | /// |
434 | /// Note: This notification is not triggered when unlinking an operation. |
435 | virtual void notifyOperationErased(Operation *op) {} |
436 | |
437 | /// Notify the listener that the specified pattern is about to be applied |
438 | /// at the specified root operation. |
439 | virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {} |
440 | |
441 | /// Notify the listener that a pattern application finished with the |
442 | /// specified status. "success" indicates that the pattern was applied |
443 | /// successfully. "failure" indicates that the pattern could not be |
444 | /// applied. The pattern may have communicated the reason for the failure |
445 | /// with `notifyMatchFailure`. |
446 | virtual void notifyPatternEnd(const Pattern &pattern, |
447 | LogicalResult status) {} |
448 | |
449 | /// Notify the listener that the pattern failed to match, and provide a |
450 | /// callback to populate a diagnostic with the reason why the failure |
451 | /// occurred. This method allows for derived listeners to optionally hook |
452 | /// into the reason why a rewrite failed, and display it to users. |
453 | virtual void |
454 | notifyMatchFailure(Location loc, |
455 | function_ref<void(Diagnostic &)> reasonCallback) {} |
456 | |
457 | static bool classof(const OpBuilder::Listener *base); |
458 | }; |
459 | |
460 | /// A listener that forwards all notifications to another listener. This |
461 | /// struct can be used as a base to create listener chains, so that multiple |
462 | /// listeners can be notified of IR changes. |
463 | struct ForwardingListener : public RewriterBase::Listener { |
464 | ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {} |
465 | |
466 | void notifyOperationInserted(Operation *op, InsertPoint previous) override { |
467 | listener->notifyOperationInserted(op, previous); |
468 | } |
469 | void notifyBlockInserted(Block *block, Region *previous, |
470 | Region::iterator previousIt) override { |
471 | listener->notifyBlockInserted(block, previous, previousIt); |
472 | } |
473 | void notifyBlockErased(Block *block) override { |
474 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
475 | rewriteListener->notifyBlockErased(block); |
476 | } |
477 | void notifyOperationModified(Operation *op) override { |
478 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
479 | rewriteListener->notifyOperationModified(op); |
480 | } |
481 | void notifyOperationReplaced(Operation *op, Operation *newOp) override { |
482 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
483 | rewriteListener->notifyOperationReplaced(op, replacement: newOp); |
484 | } |
485 | void notifyOperationReplaced(Operation *op, |
486 | ValueRange replacement) override { |
487 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
488 | rewriteListener->notifyOperationReplaced(op, replacement); |
489 | } |
490 | void notifyOperationErased(Operation *op) override { |
491 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
492 | rewriteListener->notifyOperationErased(op); |
493 | } |
494 | void notifyPatternBegin(const Pattern &pattern, Operation *op) override { |
495 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
496 | rewriteListener->notifyPatternBegin(pattern, op); |
497 | } |
498 | void notifyPatternEnd(const Pattern &pattern, |
499 | LogicalResult status) override { |
500 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
501 | rewriteListener->notifyPatternEnd(pattern, status); |
502 | } |
503 | void notifyMatchFailure( |
504 | Location loc, |
505 | function_ref<void(Diagnostic &)> reasonCallback) override { |
506 | if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener)) |
507 | rewriteListener->notifyMatchFailure(loc, reasonCallback); |
508 | } |
509 | |
510 | private: |
511 | OpBuilder::Listener *listener; |
512 | }; |
513 | |
514 | /// Move the blocks that belong to "region" before the given position in |
515 | /// another region "parent". The two regions must be different. The caller |
516 | /// is responsible for creating or updating the operation transferring flow |
517 | /// of control to the region and passing it the correct block arguments. |
518 | void inlineRegionBefore(Region ®ion, Region &parent, |
519 | Region::iterator before); |
520 | void inlineRegionBefore(Region ®ion, Block *before); |
521 | |
522 | /// Replace the results of the given (original) operation with the specified |
523 | /// list of values (replacements). The result types of the given op and the |
524 | /// replacements must match. The original op is erased. |
525 | virtual void replaceOp(Operation *op, ValueRange newValues); |
526 | |
527 | /// Replace the results of the given (original) operation with the specified |
528 | /// new op (replacement). The result types of the two ops must match. The |
529 | /// original op is erased. |
530 | virtual void replaceOp(Operation *op, Operation *newOp); |
531 | |
532 | /// Replace the results of the given (original) op with a new op that is |
533 | /// created without verification (replacement). The result values of the two |
534 | /// ops must match. The original op is erased. |
535 | template <typename OpTy, typename... Args> |
536 | OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { |
537 | auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); |
538 | replaceOp(op, newOp.getOperation()); |
539 | return newOp; |
540 | } |
541 | |
542 | /// This method erases an operation that is known to have no uses. |
543 | virtual void eraseOp(Operation *op); |
544 | |
545 | /// This method erases all operations in a block. |
546 | virtual void eraseBlock(Block *block); |
547 | |
548 | /// Inline the operations of block 'source' into block 'dest' before the given |
549 | /// position. The source block will be deleted and must have no uses. |
550 | /// 'argValues' is used to replace the block arguments of 'source'. |
551 | /// |
552 | /// If the source block is inserted at the end of the dest block, the dest |
553 | /// block must have no successors. Similarly, if the source block is inserted |
554 | /// somewhere in the middle (or beginning) of the dest block, the source block |
555 | /// must have no successors. Otherwise, the resulting IR would have |
556 | /// unreachable operations. |
557 | virtual void inlineBlockBefore(Block *source, Block *dest, |
558 | Block::iterator before, |
559 | ValueRange argValues = std::nullopt); |
560 | |
561 | /// Inline the operations of block 'source' before the operation 'op'. The |
562 | /// source block will be deleted and must have no uses. 'argValues' is used to |
563 | /// replace the block arguments of 'source' |
564 | /// |
565 | /// The source block must have no successors. Otherwise, the resulting IR |
566 | /// would have unreachable operations. |
567 | void inlineBlockBefore(Block *source, Operation *op, |
568 | ValueRange argValues = std::nullopt); |
569 | |
570 | /// Inline the operations of block 'source' into the end of block 'dest'. The |
571 | /// source block will be deleted and must have no uses. 'argValues' is used to |
572 | /// replace the block arguments of 'source' |
573 | /// |
574 | /// The dest block must have no successors. Otherwise, the resulting IR would |
575 | /// have unreachable operation. |
576 | void mergeBlocks(Block *source, Block *dest, |
577 | ValueRange argValues = std::nullopt); |
578 | |
579 | /// Split the operations starting at "before" (inclusive) out of the given |
580 | /// block into a new block, and return it. |
581 | Block *splitBlock(Block *block, Block::iterator before); |
582 | |
583 | /// Unlink this operation from its current block and insert it right before |
584 | /// `existingOp` which may be in the same or another block in the same |
585 | /// function. |
586 | void moveOpBefore(Operation *op, Operation *existingOp); |
587 | |
588 | /// Unlink this operation from its current block and insert it right before |
589 | /// `iterator` in the specified block. |
590 | void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); |
591 | |
592 | /// Unlink this operation from its current block and insert it right after |
593 | /// `existingOp` which may be in the same or another block in the same |
594 | /// function. |
595 | void moveOpAfter(Operation *op, Operation *existingOp); |
596 | |
597 | /// Unlink this operation from its current block and insert it right after |
598 | /// `iterator` in the specified block. |
599 | void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); |
600 | |
601 | /// Unlink this block and insert it right before `existingBlock`. |
602 | void moveBlockBefore(Block *block, Block *anotherBlock); |
603 | |
604 | /// Unlink this block and insert it right before the location that the given |
605 | /// iterator points to in the given region. |
606 | void moveBlockBefore(Block *block, Region *region, Region::iterator iterator); |
607 | |
608 | /// This method is used to notify the rewriter that an in-place operation |
609 | /// modification is about to happen. A call to this function *must* be |
610 | /// followed by a call to either `finalizeOpModification` or |
611 | /// `cancelOpModification`. This is a minor efficiency win (it avoids creating |
612 | /// a new operation and removing the old one) but also often allows simpler |
613 | /// code in the client. |
614 | virtual void startOpModification(Operation *op) {} |
615 | |
616 | /// This method is used to signal the end of an in-place modification of the |
617 | /// given operation. This can only be called on operations that were provided |
618 | /// to a call to `startOpModification`. |
619 | virtual void finalizeOpModification(Operation *op); |
620 | |
621 | /// This method cancels a pending in-place modification. This can only be |
622 | /// called on operations that were provided to a call to |
623 | /// `startOpModification`. |
624 | virtual void cancelOpModification(Operation *op) {} |
625 | |
626 | /// This method is a utility wrapper around an in-place modification of an |
627 | /// operation. It wraps calls to `startOpModification` and |
628 | /// `finalizeOpModification` around the given callable. |
629 | template <typename CallableT> |
630 | void modifyOpInPlace(Operation *root, CallableT &&callable) { |
631 | startOpModification(op: root); |
632 | callable(); |
633 | finalizeOpModification(op: root); |
634 | } |
635 | |
636 | /// Find uses of `from` and replace them with `to`. Also notify the listener |
637 | /// about every in-place op modification (for every use that was replaced). |
638 | void replaceAllUsesWith(Value from, Value to) { |
639 | for (OpOperand &operand : llvm::make_early_inc_range(Range: from.getUses())) { |
640 | Operation *op = operand.getOwner(); |
641 | modifyOpInPlace(root: op, callable: [&]() { operand.set(to); }); |
642 | } |
643 | } |
644 | void replaceAllUsesWith(Block *from, Block *to) { |
645 | for (BlockOperand &operand : llvm::make_early_inc_range(Range: from->getUses())) { |
646 | Operation *op = operand.getOwner(); |
647 | modifyOpInPlace(root: op, callable: [&]() { operand.set(to); }); |
648 | } |
649 | } |
650 | void replaceAllUsesWith(ValueRange from, ValueRange to) { |
651 | assert(from.size() == to.size() && "incorrect number of replacements" ); |
652 | for (auto it : llvm::zip(t&: from, u&: to)) |
653 | replaceAllUsesWith(from: std::get<0>(t&: it), to: std::get<1>(t&: it)); |
654 | } |
655 | |
656 | /// Find uses of `from` and replace them with `to`. Also notify the listener |
657 | /// about every in-place op modification (for every use that was replaced) |
658 | /// and that the `from` operation is about to be replaced. |
659 | /// |
660 | /// Note: This function cannot be called `replaceAllUsesWith` because the |
661 | /// overload resolution, when called with an op that can be implicitly |
662 | /// converted to a Value, would be ambiguous. |
663 | void replaceAllOpUsesWith(Operation *from, ValueRange to); |
664 | void replaceAllOpUsesWith(Operation *from, Operation *to); |
665 | |
666 | /// Find uses of `from` and replace them with `to` if the `functor` returns |
667 | /// true. Also notify the listener about every in-place op modification (for |
668 | /// every use that was replaced). The optional `allUsesReplaced` flag is set |
669 | /// to "true" if all uses were replaced. |
670 | void replaceUsesWithIf(Value from, Value to, |
671 | function_ref<bool(OpOperand &)> functor, |
672 | bool *allUsesReplaced = nullptr); |
673 | void replaceUsesWithIf(ValueRange from, ValueRange to, |
674 | function_ref<bool(OpOperand &)> functor, |
675 | bool *allUsesReplaced = nullptr); |
676 | // Note: This function cannot be called `replaceOpUsesWithIf` because the |
677 | // overload resolution, when called with an op that can be implicitly |
678 | // converted to a Value, would be ambiguous. |
679 | void replaceOpUsesWithIf(Operation *from, ValueRange to, |
680 | function_ref<bool(OpOperand &)> functor, |
681 | bool *allUsesReplaced = nullptr) { |
682 | replaceUsesWithIf(from: from->getResults(), to, functor, allUsesReplaced); |
683 | } |
684 | |
685 | /// Find uses of `from` within `block` and replace them with `to`. Also notify |
686 | /// the listener about every in-place op modification (for every use that was |
687 | /// replaced). The optional `allUsesReplaced` flag is set to "true" if all |
688 | /// uses were replaced. |
689 | void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, |
690 | Block *block, bool *allUsesReplaced = nullptr) { |
691 | replaceOpUsesWithIf( |
692 | from: op, to: newValues, |
693 | functor: [block](OpOperand &use) { |
694 | return block->getParentOp()->isProperAncestor(other: use.getOwner()); |
695 | }, |
696 | allUsesReplaced); |
697 | } |
698 | |
699 | /// Find uses of `from` and replace them with `to` except if the user is |
700 | /// `exceptedUser`. Also notify the listener about every in-place op |
701 | /// modification (for every use that was replaced). |
702 | void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) { |
703 | return replaceUsesWithIf(from, to, functor: [&](OpOperand &use) { |
704 | Operation *user = use.getOwner(); |
705 | return user != exceptedUser; |
706 | }); |
707 | } |
708 | void replaceAllUsesExcept(Value from, Value to, |
709 | const SmallPtrSetImpl<Operation *> &preservedUsers); |
710 | |
711 | /// Used to notify the listener that the IR failed to be rewritten because of |
712 | /// a match failure, and provide a callback to populate a diagnostic with the |
713 | /// reason why the failure occurred. This method allows for derived rewriters |
714 | /// to optionally hook into the reason why a rewrite failed, and display it to |
715 | /// users. |
716 | template <typename CallbackT> |
717 | std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> |
718 | notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { |
719 | if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener)) |
720 | rewriteListener->notifyMatchFailure( |
721 | loc, reasonCallback: function_ref<void(Diagnostic &)>(reasonCallback)); |
722 | return failure(); |
723 | } |
724 | template <typename CallbackT> |
725 | std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> |
726 | notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { |
727 | if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener)) |
728 | rewriteListener->notifyMatchFailure( |
729 | loc: op->getLoc(), reasonCallback: function_ref<void(Diagnostic &)>(reasonCallback)); |
730 | return failure(); |
731 | } |
732 | template <typename ArgT> |
733 | LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { |
734 | return notifyMatchFailure(std::forward<ArgT>(arg), |
735 | [&](Diagnostic &diag) { diag << msg; }); |
736 | } |
737 | template <typename ArgT> |
738 | LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) { |
739 | return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg)); |
740 | } |
741 | |
742 | protected: |
743 | /// Initialize the builder. |
744 | explicit RewriterBase(MLIRContext *ctx, |
745 | OpBuilder::Listener *listener = nullptr) |
746 | : OpBuilder(ctx, listener) {} |
747 | explicit RewriterBase(const OpBuilder &otherBuilder) |
748 | : OpBuilder(otherBuilder) {} |
749 | explicit RewriterBase(Operation *op, OpBuilder::Listener *listener = nullptr) |
750 | : OpBuilder(op, listener) {} |
751 | virtual ~RewriterBase(); |
752 | |
753 | private: |
754 | void operator=(const RewriterBase &) = delete; |
755 | RewriterBase(const RewriterBase &) = delete; |
756 | }; |
757 | |
758 | //===----------------------------------------------------------------------===// |
759 | // IRRewriter |
760 | //===----------------------------------------------------------------------===// |
761 | |
762 | /// This class coordinates rewriting a piece of IR outside of a pattern rewrite, |
763 | /// providing a way to keep track of the mutations made to the IR. This class |
764 | /// should only be used in situations where another `RewriterBase` instance, |
765 | /// such as a `PatternRewriter`, is not available. |
766 | class IRRewriter : public RewriterBase { |
767 | public: |
768 | explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) |
769 | : RewriterBase(ctx, listener) {} |
770 | explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} |
771 | explicit IRRewriter(Operation *op, OpBuilder::Listener *listener = nullptr) |
772 | : RewriterBase(op, listener) {} |
773 | }; |
774 | |
775 | //===----------------------------------------------------------------------===// |
776 | // PatternRewriter |
777 | //===----------------------------------------------------------------------===// |
778 | |
779 | /// A special type of `RewriterBase` that coordinates the application of a |
780 | /// rewrite pattern on the current IR being matched, providing a way to keep |
781 | /// track of any mutations made. This class should be used to perform all |
782 | /// necessary IR mutations within a rewrite pattern, as the pattern driver may |
783 | /// be tracking various state that would be invalidated when a mutation takes |
784 | /// place. |
785 | class PatternRewriter : public RewriterBase { |
786 | public: |
787 | using RewriterBase::RewriterBase; |
788 | |
789 | /// A hook used to indicate if the pattern rewriter can recover from failure |
790 | /// during the rewrite stage of a pattern. For example, if the pattern |
791 | /// rewriter supports rollback, it may progress smoothly even if IR was |
792 | /// changed during the rewrite. |
793 | virtual bool canRecoverFromRewriteFailure() const { return false; } |
794 | }; |
795 | |
796 | } // namespace mlir |
797 | |
798 | // Optionally expose PDL pattern matching methods. |
799 | #include "PDLPatternMatch.h.inc" |
800 | |
801 | namespace mlir { |
802 | |
803 | //===----------------------------------------------------------------------===// |
804 | // RewritePatternSet |
805 | //===----------------------------------------------------------------------===// |
806 | |
807 | class RewritePatternSet { |
808 | using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; |
809 | |
810 | public: |
811 | RewritePatternSet(MLIRContext *context) : context(context) {} |
812 | |
813 | /// Construct a RewritePatternSet populated with the given pattern. |
814 | RewritePatternSet(MLIRContext *context, |
815 | std::unique_ptr<RewritePattern> pattern) |
816 | : context(context) { |
817 | nativePatterns.emplace_back(args: std::move(pattern)); |
818 | } |
819 | RewritePatternSet(PDLPatternModule &&pattern) |
820 | : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {} |
821 | |
822 | MLIRContext *getContext() const { return context; } |
823 | |
824 | /// Return the native patterns held in this list. |
825 | NativePatternListT &getNativePatterns() { return nativePatterns; } |
826 | |
827 | /// Return the PDL patterns held in this list. |
828 | PDLPatternModule &getPDLPatterns() { return pdlPatterns; } |
829 | |
830 | /// Clear out all of the held patterns in this list. |
831 | void clear() { |
832 | nativePatterns.clear(); |
833 | pdlPatterns.clear(); |
834 | } |
835 | |
836 | //===--------------------------------------------------------------------===// |
837 | // 'add' methods for adding patterns to the set. |
838 | //===--------------------------------------------------------------------===// |
839 | |
840 | /// Add an instance of each of the pattern types 'Ts' to the pattern list with |
841 | /// the given arguments. Return a reference to `this` for chaining insertions. |
842 | /// Note: ConstructorArg is necessary here to separate the two variadic lists. |
843 | template <typename... Ts, typename ConstructorArg, |
844 | typename... ConstructorArgs, |
845 | typename = std::enable_if_t<sizeof...(Ts) != 0>> |
846 | RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) { |
847 | // The following expands a call to emplace_back for each of the pattern |
848 | // types 'Ts'. |
849 | (addImpl<Ts>(/*debugLabels=*/std::nullopt, |
850 | std::forward<ConstructorArg>(arg), |
851 | std::forward<ConstructorArgs>(args)...), |
852 | ...); |
853 | return *this; |
854 | } |
855 | /// An overload of the above `add` method that allows for attaching a set |
856 | /// of debug labels to the attached patterns. This is useful for labeling |
857 | /// groups of patterns that may be shared between multiple different |
858 | /// passes/users. |
859 | template <typename... Ts, typename ConstructorArg, |
860 | typename... ConstructorArgs, |
861 | typename = std::enable_if_t<sizeof...(Ts) != 0>> |
862 | RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels, |
863 | ConstructorArg &&arg, |
864 | ConstructorArgs &&...args) { |
865 | // The following expands a call to emplace_back for each of the pattern |
866 | // types 'Ts'. |
867 | (addImpl<Ts>(debugLabels, arg, args...), ...); |
868 | return *this; |
869 | } |
870 | |
871 | /// Add an instance of each of the pattern types 'Ts'. Return a reference to |
872 | /// `this` for chaining insertions. |
873 | template <typename... Ts> |
874 | RewritePatternSet &add() { |
875 | (addImpl<Ts>(), ...); |
876 | return *this; |
877 | } |
878 | |
879 | /// Add the given native pattern to the pattern list. Return a reference to |
880 | /// `this` for chaining insertions. |
881 | RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) { |
882 | nativePatterns.emplace_back(args: std::move(pattern)); |
883 | return *this; |
884 | } |
885 | |
886 | /// Add the given PDL pattern to the pattern list. Return a reference to |
887 | /// `this` for chaining insertions. |
888 | RewritePatternSet &add(PDLPatternModule &&pattern) { |
889 | pdlPatterns.mergeIn(std::move(pattern)); |
890 | return *this; |
891 | } |
892 | |
893 | // Add a matchAndRewrite style pattern represented as a C function pointer. |
894 | template <typename OpType> |
895 | RewritePatternSet & |
896 | add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), |
897 | PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) { |
898 | struct FnPattern final : public OpRewritePattern<OpType> { |
899 | FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), |
900 | MLIRContext *context, PatternBenefit benefit, |
901 | ArrayRef<StringRef> generatedNames) |
902 | : OpRewritePattern<OpType>(context, benefit, generatedNames), |
903 | implFn(implFn) {} |
904 | |
905 | LogicalResult matchAndRewrite(OpType op, |
906 | PatternRewriter &rewriter) const override { |
907 | return implFn(op, rewriter); |
908 | } |
909 | |
910 | private: |
911 | LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); |
912 | }; |
913 | add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit, |
914 | generatedNames)); |
915 | return *this; |
916 | } |
917 | |
918 | //===--------------------------------------------------------------------===// |
919 | // Pattern Insertion |
920 | //===--------------------------------------------------------------------===// |
921 | |
922 | // TODO: These are soft deprecated in favor of the 'add' methods above. |
923 | |
924 | /// Add an instance of each of the pattern types 'Ts' to the pattern list with |
925 | /// the given arguments. Return a reference to `this` for chaining insertions. |
926 | /// Note: ConstructorArg is necessary here to separate the two variadic lists. |
927 | template <typename... Ts, typename ConstructorArg, |
928 | typename... ConstructorArgs, |
929 | typename = std::enable_if_t<sizeof...(Ts) != 0>> |
930 | RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) { |
931 | // The following expands a call to emplace_back for each of the pattern |
932 | // types 'Ts'. |
933 | (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...); |
934 | return *this; |
935 | } |
936 | |
937 | /// Add an instance of each of the pattern types 'Ts'. Return a reference to |
938 | /// `this` for chaining insertions. |
939 | template <typename... Ts> |
940 | RewritePatternSet &insert() { |
941 | (addImpl<Ts>(), ...); |
942 | return *this; |
943 | } |
944 | |
945 | /// Add the given native pattern to the pattern list. Return a reference to |
946 | /// `this` for chaining insertions. |
947 | RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) { |
948 | nativePatterns.emplace_back(args: std::move(pattern)); |
949 | return *this; |
950 | } |
951 | |
952 | /// Add the given PDL pattern to the pattern list. Return a reference to |
953 | /// `this` for chaining insertions. |
954 | RewritePatternSet &insert(PDLPatternModule &&pattern) { |
955 | pdlPatterns.mergeIn(std::move(pattern)); |
956 | return *this; |
957 | } |
958 | |
959 | // Add a matchAndRewrite style pattern represented as a C function pointer. |
960 | template <typename OpType> |
961 | RewritePatternSet & |
962 | insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) { |
963 | struct FnPattern final : public OpRewritePattern<OpType> { |
964 | FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), |
965 | MLIRContext *context) |
966 | : OpRewritePattern<OpType>(context), implFn(implFn) { |
967 | this->setDebugName(llvm::getTypeName<FnPattern>()); |
968 | } |
969 | |
970 | LogicalResult matchAndRewrite(OpType op, |
971 | PatternRewriter &rewriter) const override { |
972 | return implFn(op, rewriter); |
973 | } |
974 | |
975 | private: |
976 | LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); |
977 | }; |
978 | add(std::make_unique<FnPattern>(std::move(implFn), getContext())); |
979 | return *this; |
980 | } |
981 | |
982 | private: |
983 | /// Add an instance of the pattern type 'T'. Return a reference to `this` for |
984 | /// chaining insertions. |
985 | template <typename T, typename... Args> |
986 | std::enable_if_t<std::is_base_of<RewritePattern, T>::value> |
987 | addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { |
988 | std::unique_ptr<T> pattern = |
989 | RewritePattern::create<T>(std::forward<Args>(args)...); |
990 | pattern->addDebugLabels(debugLabels); |
991 | nativePatterns.emplace_back(std::move(pattern)); |
992 | } |
993 | |
994 | template <typename T, typename... Args> |
995 | std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value> |
996 | addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { |
997 | // TODO: Add the provided labels to the PDL pattern when PDL supports |
998 | // labels. |
999 | pdlPatterns.mergeIn(T(std::forward<Args>(args)...)); |
1000 | } |
1001 | |
1002 | MLIRContext *const context; |
1003 | NativePatternListT nativePatterns; |
1004 | |
1005 | // Patterns expressed with PDL. This will compile to a stub class when PDL is |
1006 | // not enabled. |
1007 | PDLPatternModule pdlPatterns; |
1008 | }; |
1009 | |
1010 | } // namespace mlir |
1011 | |
1012 | #endif // MLIR_IR_PATTERNMATCH_H |
1013 | |