1//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PATTERNMATCH_H
10#define MLIR_IR_PATTERNMATCH_H
11
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "llvm/ADT/FunctionExtras.h"
15#include "llvm/Support/TypeName.h"
16#include <optional>
17
18using llvm::SmallPtrSetImpl;
19namespace mlir {
20
21class PatternRewriter;
22
23//===----------------------------------------------------------------------===//
24// PatternBenefit class
25//===----------------------------------------------------------------------===//
26
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K. The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match. A default-constructed PatternBenefit holds that sentinel,
/// which is also what `impossibleToMatch()` returns.
///
/// The comparison operators order benefits by their stored value. The sentinel
/// is stored as 65535 (the top of the documented range), so an
/// impossible-to-match benefit compares greater than any benefit below that
/// value.
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// Construct the "impossible to match" sentinel benefit.
  PatternBenefit() = default;
  /// Construct a benefit from `benefit`. NOTE(review): any range checking
  /// happens in the out-of-line definition, which is not visible here.
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used by patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
  /// Return true if this benefit is the impossible-to-match sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  /// The stored benefit; defaults to the impossible-to-match sentinel.
  unsigned short representation{ImpossibleToMatchSentinel};
};
64
65//===----------------------------------------------------------------------===//
66// Pattern
67//===----------------------------------------------------------------------===//
68
69/// This class contains all of the data related to a pattern, but does not
70/// contain any methods or logic for the actual matching. This class is solely
71/// used to interface with the metadata of a pattern, such as the benefit or
72/// root operation.
73class Pattern {
74 /// This enum represents the kind of value used to select the root operations
75 /// that match this pattern.
76 enum class RootKind {
77 /// The pattern root matches "any" operation.
78 Any,
79 /// The pattern root is matched using a concrete operation name.
80 OperationName,
81 /// The pattern root is matched using an interface ID.
82 InterfaceID,
83 /// The patter root is matched using a trait ID.
84 TraitID
85 };
86
87public:
88 /// Return a list of operations that may be generated when rewriting an
89 /// operation instance with this pattern.
90 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
91
92 /// Return the root node that this pattern matches. Patterns that can match
93 /// multiple root types return std::nullopt.
94 std::optional<OperationName> getRootKind() const {
95 if (rootKind == RootKind::OperationName)
96 return OperationName::getFromOpaquePointer(pointer: rootValue);
97 return std::nullopt;
98 }
99
100 /// Return the interface ID used to match the root operation of this pattern.
101 /// If the pattern does not use an interface ID for deciding the root match,
102 /// this returns std::nullopt.
103 std::optional<TypeID> getRootInterfaceID() const {
104 if (rootKind == RootKind::InterfaceID)
105 return TypeID::getFromOpaquePointer(pointer: rootValue);
106 return std::nullopt;
107 }
108
109 /// Return the trait ID used to match the root operation of this pattern.
110 /// If the pattern does not use a trait ID for deciding the root match, this
111 /// returns std::nullopt.
112 std::optional<TypeID> getRootTraitID() const {
113 if (rootKind == RootKind::TraitID)
114 return TypeID::getFromOpaquePointer(pointer: rootValue);
115 return std::nullopt;
116 }
117
118 /// Return the benefit (the inverse of "cost") of matching this pattern. The
119 /// benefit of a Pattern is always static - rewrites that may have dynamic
120 /// benefit can be instantiated multiple times (different Pattern instances)
121 /// for each benefit that they may return, and be guarded by different match
122 /// condition predicates.
123 PatternBenefit getBenefit() const { return benefit; }
124
125 /// Returns true if this pattern is known to result in recursive application,
126 /// i.e. this pattern may generate IR that also matches this pattern, but is
127 /// known to bound the recursion. This signals to a rewrite driver that it is
128 /// safe to apply this pattern recursively to generated IR.
129 bool hasBoundedRewriteRecursion() const {
130 return contextAndHasBoundedRecursion.getInt();
131 }
132
133 /// Return the MLIRContext used to create this pattern.
134 MLIRContext *getContext() const {
135 return contextAndHasBoundedRecursion.getPointer();
136 }
137
138 /// Return a readable name for this pattern. This name should only be used for
139 /// debugging purposes, and may be empty.
140 StringRef getDebugName() const { return debugName; }
141
142 /// Set the human readable debug name used for this pattern. This name will
143 /// only be used for debugging purposes.
144 void setDebugName(StringRef name) { debugName = name; }
145
146 /// Return the set of debug labels attached to this pattern.
147 ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
148
149 /// Add the provided debug labels to this pattern.
150 void addDebugLabels(ArrayRef<StringRef> labels) {
151 debugLabels.append(in_start: labels.begin(), in_end: labels.end());
152 }
153 void addDebugLabels(StringRef label) { debugLabels.push_back(Elt: label); }
154
155protected:
156 /// This class acts as a special tag that makes the desire to match "any"
157 /// operation type explicit. This helps to avoid unnecessary usages of this
158 /// feature, and ensures that the user is making a conscious decision.
159 struct MatchAnyOpTypeTag {};
160 /// This class acts as a special tag that makes the desire to match any
161 /// operation that implements a given interface explicit. This helps to avoid
162 /// unnecessary usages of this feature, and ensures that the user is making a
163 /// conscious decision.
164 struct MatchInterfaceOpTypeTag {};
165 /// This class acts as a special tag that makes the desire to match any
166 /// operation that implements a given trait explicit. This helps to avoid
167 /// unnecessary usages of this feature, and ensures that the user is making a
168 /// conscious decision.
169 struct MatchTraitOpTypeTag {};
170
171 /// Construct a pattern with a certain benefit that matches the operation
172 /// with the given root name.
173 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
174 ArrayRef<StringRef> generatedNames = {});
175 /// Construct a pattern that may match any operation type. `generatedNames`
176 /// contains the names of operations that may be generated during a successful
177 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
178 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
179 /// always be supplied here.
180 Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
181 ArrayRef<StringRef> generatedNames = {});
182 /// Construct a pattern that may match any operation that implements the
183 /// interface defined by the provided `interfaceID`. `generatedNames` contains
184 /// the names of operations that may be generated during a successful rewrite.
185 /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
186 /// interface" behavior is what the user actually desired,
187 /// `MatchInterfaceOpTypeTag()` should always be supplied here.
188 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189 PatternBenefit benefit, MLIRContext *context,
190 ArrayRef<StringRef> generatedNames = {});
191 /// Construct a pattern that may match any operation that implements the
192 /// trait defined by the provided `traitID`. `generatedNames` contains the
193 /// names of operations that may be generated during a successful rewrite.
194 /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
195 /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
196 /// always be supplied here.
197 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
199
200 /// Set the flag detailing if this pattern has bounded rewrite recursion or
201 /// not.
202 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
203 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
204 }
205
206private:
207 Pattern(const void *rootValue, RootKind rootKind,
208 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
209 MLIRContext *context);
210
211 /// The value used to match the root operation of the pattern.
212 const void *rootValue;
213 RootKind rootKind;
214
215 /// The expected benefit of matching this pattern.
216 const PatternBenefit benefit;
217
218 /// The context this pattern was created from, and a boolean flag indicating
219 /// whether this pattern has bounded recursion or not.
220 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
221
222 /// A list of the potential operations that may be generated when rewriting
223 /// an op with this pattern.
224 SmallVector<OperationName, 2> generatedOps;
225
226 /// A readable name for this pattern. May be empty.
227 StringRef debugName;
228
229 /// The set of debug labels attached to this pattern.
230 SmallVector<StringRef, 0> debugLabels;
231};
232
233//===----------------------------------------------------------------------===//
234// RewritePattern
235//===----------------------------------------------------------------------===//
236
237/// RewritePattern is the common base class for all DAG to DAG replacements.
238/// There are two possible usages of this class:
239/// * Multi-step RewritePattern with "match" and "rewrite"
240/// - By overloading the "match" and "rewrite" functions, the user can
241/// separate the concerns of matching and rewriting.
242/// * Single-step RewritePattern with "matchAndRewrite"
243/// - By overloading the "matchAndRewrite" function, the user can perform
244/// the rewrite in the same call as the match.
245///
246class RewritePattern : public Pattern {
247public:
248 virtual ~RewritePattern() = default;
249
250 /// Rewrite the IR rooted at the specified operation with the result of
251 /// this pattern, generating any new operations with the specified
252 /// builder. If an unexpected error is encountered (an internal
253 /// compiler error), it is emitted through the normal MLIR diagnostic
254 /// hooks and the IR is left in a valid state.
255 virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
256
257 /// Attempt to match against code rooted at the specified operation,
258 /// which is the same operation code as getRootKind().
259 virtual LogicalResult match(Operation *op) const;
260
261 /// Attempt to match against code rooted at the specified operation,
262 /// which is the same operation code as getRootKind(). If successful, this
263 /// function will automatically perform the rewrite.
264 virtual LogicalResult matchAndRewrite(Operation *op,
265 PatternRewriter &rewriter) const {
266 if (succeeded(result: match(op))) {
267 rewrite(op, rewriter);
268 return success();
269 }
270 return failure();
271 }
272
273 /// This method provides a convenient interface for creating and initializing
274 /// derived rewrite patterns of the given type `T`.
275 template <typename T, typename... Args>
276 static std::unique_ptr<T> create(Args &&...args) {
277 std::unique_ptr<T> pattern =
278 std::make_unique<T>(std::forward<Args>(args)...);
279 initializePattern<T>(*pattern);
280
281 // Set a default debug name if one wasn't provided.
282 if (pattern->getDebugName().empty())
283 pattern->setDebugName(llvm::getTypeName<T>());
284 return pattern;
285 }
286
287protected:
288 /// Inherit the base constructors from `Pattern`.
289 using Pattern::Pattern;
290
291private:
292 /// Trait to check if T provides a `getOperationName` method.
293 template <typename T, typename... Args>
294 using has_initialize = decltype(std::declval<T>().initialize());
295 template <typename T>
296 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
297
298 /// Initialize the derived pattern by calling its `initialize` method.
299 template <typename T>
300 static std::enable_if_t<detect_has_initialize<T>::value>
301 initializePattern(T &pattern) {
302 pattern.initialize();
303 }
304 /// Empty derived pattern initializer for patterns that do not have an
305 /// initialize method.
306 template <typename T>
307 static std::enable_if_t<!detect_has_initialize<T>::value>
308 initializePattern(T &) {}
309
310 /// An anchor for the virtual table.
311 virtual void anchor();
312};
313
314namespace detail {
315/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
316/// allows for matching and rewriting against an instance of a derived operation
317/// class or Interface.
318template <typename SourceOp>
319struct OpOrInterfaceRewritePatternBase : public RewritePattern {
320 using RewritePattern::RewritePattern;
321
322 /// Wrappers around the RewritePattern methods that pass the derived op type.
323 void rewrite(Operation *op, PatternRewriter &rewriter) const final {
324 rewrite(cast<SourceOp>(op), rewriter);
325 }
326 LogicalResult match(Operation *op) const final {
327 return match(cast<SourceOp>(op));
328 }
329 LogicalResult matchAndRewrite(Operation *op,
330 PatternRewriter &rewriter) const final {
331 return matchAndRewrite(cast<SourceOp>(op), rewriter);
332 }
333
334 /// Rewrite and Match methods that operate on the SourceOp type. These must be
335 /// overridden by the derived pattern class.
336 virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
337 llvm_unreachable("must override rewrite or matchAndRewrite");
338 }
339 virtual LogicalResult match(SourceOp op) const {
340 llvm_unreachable("must override match or matchAndRewrite");
341 }
342 virtual LogicalResult matchAndRewrite(SourceOp op,
343 PatternRewriter &rewriter) const {
344 if (succeeded(match(op))) {
345 rewrite(op, rewriter);
346 return success();
347 }
348 return failure();
349 }
350};
351} // namespace detail
352
353/// OpRewritePattern is a wrapper around RewritePattern that allows for
354/// matching and rewriting against an instance of a derived operation class as
355/// opposed to a raw Operation.
356template <typename SourceOp>
357struct OpRewritePattern
358 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
359 /// Patterns must specify the root operation name they match against, and can
360 /// also specify the benefit of the pattern matching and a list of generated
361 /// ops.
362 OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
363 ArrayRef<StringRef> generatedNames = {})
364 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
365 SourceOp::getOperationName(), benefit, context, generatedNames) {}
366};
367
368/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
369/// matching and rewriting against an instance of an operation interface instead
370/// of a raw Operation.
371template <typename SourceOp>
372struct OpInterfaceRewritePattern
373 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
374 OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
375 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376 Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
377 benefit, context) {}
378};
379
380/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
381/// matching and rewriting against instances of an operation that possess a
382/// given trait.
383template <template <typename> class TraitType>
384class OpTraitRewritePattern : public RewritePattern {
385public:
386 OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
387 : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
388 benefit, context) {}
389};
390
391//===----------------------------------------------------------------------===//
392// RewriterBase
393//===----------------------------------------------------------------------===//
394
395/// This class coordinates the application of a rewrite on a set of IR,
396/// providing a way for clients to track mutations and create new operations.
397/// This class serves as a common API for IR mutation between pattern rewrites
398/// and non-pattern rewrites, and facilitates the development of shared
399/// IR transformation utilities.
400class RewriterBase : public OpBuilder {
401public:
402 struct Listener : public OpBuilder::Listener {
403 Listener()
404 : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
405
406 /// Notify the listener that the specified block is about to be erased.
407 /// At this point, the block has zero uses.
408 virtual void notifyBlockErased(Block *block) {}
409
410 /// Notify the listener that the specified operation was modified in-place.
411 virtual void notifyOperationModified(Operation *op) {}
412
413 /// Notify the listener that all uses of the specified operation's results
414 /// are about to be replaced with the results of another operation. This is
415 /// called before the uses of the old operation have been changed.
416 ///
417 /// By default, this function calls the "operation replaced with values"
418 /// notification.
419 virtual void notifyOperationReplaced(Operation *op,
420 Operation *replacement) {
421 notifyOperationReplaced(op, replacement: replacement->getResults());
422 }
423
424 /// Notify the listener that all uses of the specified operation's results
425 /// are about to be replaced with the a range of values, potentially
426 /// produced by other operations. This is called before the uses of the
427 /// operation have been changed.
428 virtual void notifyOperationReplaced(Operation *op,
429 ValueRange replacement) {}
430
431 /// Notify the listener that the specified operation is about to be erased.
432 /// At this point, the operation has zero uses.
433 ///
434 /// Note: This notification is not triggered when unlinking an operation.
435 virtual void notifyOperationErased(Operation *op) {}
436
437 /// Notify the listener that the specified pattern is about to be applied
438 /// at the specified root operation.
439 virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
440
441 /// Notify the listener that a pattern application finished with the
442 /// specified status. "success" indicates that the pattern was applied
443 /// successfully. "failure" indicates that the pattern could not be
444 /// applied. The pattern may have communicated the reason for the failure
445 /// with `notifyMatchFailure`.
446 virtual void notifyPatternEnd(const Pattern &pattern,
447 LogicalResult status) {}
448
449 /// Notify the listener that the pattern failed to match, and provide a
450 /// callback to populate a diagnostic with the reason why the failure
451 /// occurred. This method allows for derived listeners to optionally hook
452 /// into the reason why a rewrite failed, and display it to users.
453 virtual void
454 notifyMatchFailure(Location loc,
455 function_ref<void(Diagnostic &)> reasonCallback) {}
456
457 static bool classof(const OpBuilder::Listener *base);
458 };
459
460 /// A listener that forwards all notifications to another listener. This
461 /// struct can be used as a base to create listener chains, so that multiple
462 /// listeners can be notified of IR changes.
463 struct ForwardingListener : public RewriterBase::Listener {
464 ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
465
466 void notifyOperationInserted(Operation *op, InsertPoint previous) override {
467 listener->notifyOperationInserted(op, previous);
468 }
469 void notifyBlockInserted(Block *block, Region *previous,
470 Region::iterator previousIt) override {
471 listener->notifyBlockInserted(block, previous, previousIt);
472 }
473 void notifyBlockErased(Block *block) override {
474 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
475 rewriteListener->notifyBlockErased(block);
476 }
477 void notifyOperationModified(Operation *op) override {
478 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
479 rewriteListener->notifyOperationModified(op);
480 }
481 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
482 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
483 rewriteListener->notifyOperationReplaced(op, replacement: newOp);
484 }
485 void notifyOperationReplaced(Operation *op,
486 ValueRange replacement) override {
487 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
488 rewriteListener->notifyOperationReplaced(op, replacement);
489 }
490 void notifyOperationErased(Operation *op) override {
491 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
492 rewriteListener->notifyOperationErased(op);
493 }
494 void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
495 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
496 rewriteListener->notifyPatternBegin(pattern, op);
497 }
498 void notifyPatternEnd(const Pattern &pattern,
499 LogicalResult status) override {
500 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
501 rewriteListener->notifyPatternEnd(pattern, status);
502 }
503 void notifyMatchFailure(
504 Location loc,
505 function_ref<void(Diagnostic &)> reasonCallback) override {
506 if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(Val: listener))
507 rewriteListener->notifyMatchFailure(loc, reasonCallback);
508 }
509
510 private:
511 OpBuilder::Listener *listener;
512 };
513
514 /// Move the blocks that belong to "region" before the given position in
515 /// another region "parent". The two regions must be different. The caller
516 /// is responsible for creating or updating the operation transferring flow
517 /// of control to the region and passing it the correct block arguments.
518 void inlineRegionBefore(Region &region, Region &parent,
519 Region::iterator before);
520 void inlineRegionBefore(Region &region, Block *before);
521
522 /// Replace the results of the given (original) operation with the specified
523 /// list of values (replacements). The result types of the given op and the
524 /// replacements must match. The original op is erased.
525 virtual void replaceOp(Operation *op, ValueRange newValues);
526
527 /// Replace the results of the given (original) operation with the specified
528 /// new op (replacement). The result types of the two ops must match. The
529 /// original op is erased.
530 virtual void replaceOp(Operation *op, Operation *newOp);
531
532 /// Replace the results of the given (original) op with a new op that is
533 /// created without verification (replacement). The result values of the two
534 /// ops must match. The original op is erased.
535 template <typename OpTy, typename... Args>
536 OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
537 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
538 replaceOp(op, newOp.getOperation());
539 return newOp;
540 }
541
542 /// This method erases an operation that is known to have no uses.
543 virtual void eraseOp(Operation *op);
544
545 /// This method erases all operations in a block.
546 virtual void eraseBlock(Block *block);
547
548 /// Inline the operations of block 'source' into block 'dest' before the given
549 /// position. The source block will be deleted and must have no uses.
550 /// 'argValues' is used to replace the block arguments of 'source'.
551 ///
552 /// If the source block is inserted at the end of the dest block, the dest
553 /// block must have no successors. Similarly, if the source block is inserted
554 /// somewhere in the middle (or beginning) of the dest block, the source block
555 /// must have no successors. Otherwise, the resulting IR would have
556 /// unreachable operations.
557 virtual void inlineBlockBefore(Block *source, Block *dest,
558 Block::iterator before,
559 ValueRange argValues = std::nullopt);
560
561 /// Inline the operations of block 'source' before the operation 'op'. The
562 /// source block will be deleted and must have no uses. 'argValues' is used to
563 /// replace the block arguments of 'source'
564 ///
565 /// The source block must have no successors. Otherwise, the resulting IR
566 /// would have unreachable operations.
567 void inlineBlockBefore(Block *source, Operation *op,
568 ValueRange argValues = std::nullopt);
569
570 /// Inline the operations of block 'source' into the end of block 'dest'. The
571 /// source block will be deleted and must have no uses. 'argValues' is used to
572 /// replace the block arguments of 'source'
573 ///
574 /// The dest block must have no successors. Otherwise, the resulting IR would
575 /// have unreachable operation.
576 void mergeBlocks(Block *source, Block *dest,
577 ValueRange argValues = std::nullopt);
578
579 /// Split the operations starting at "before" (inclusive) out of the given
580 /// block into a new block, and return it.
581 Block *splitBlock(Block *block, Block::iterator before);
582
583 /// Unlink this operation from its current block and insert it right before
584 /// `existingOp` which may be in the same or another block in the same
585 /// function.
586 void moveOpBefore(Operation *op, Operation *existingOp);
587
588 /// Unlink this operation from its current block and insert it right before
589 /// `iterator` in the specified block.
590 void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
591
592 /// Unlink this operation from its current block and insert it right after
593 /// `existingOp` which may be in the same or another block in the same
594 /// function.
595 void moveOpAfter(Operation *op, Operation *existingOp);
596
597 /// Unlink this operation from its current block and insert it right after
598 /// `iterator` in the specified block.
599 void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
600
601 /// Unlink this block and insert it right before `existingBlock`.
602 void moveBlockBefore(Block *block, Block *anotherBlock);
603
604 /// Unlink this block and insert it right before the location that the given
605 /// iterator points to in the given region.
606 void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
607
608 /// This method is used to notify the rewriter that an in-place operation
609 /// modification is about to happen. A call to this function *must* be
610 /// followed by a call to either `finalizeOpModification` or
611 /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
612 /// a new operation and removing the old one) but also often allows simpler
613 /// code in the client.
614 virtual void startOpModification(Operation *op) {}
615
616 /// This method is used to signal the end of an in-place modification of the
617 /// given operation. This can only be called on operations that were provided
618 /// to a call to `startOpModification`.
619 virtual void finalizeOpModification(Operation *op);
620
621 /// This method cancels a pending in-place modification. This can only be
622 /// called on operations that were provided to a call to
623 /// `startOpModification`.
624 virtual void cancelOpModification(Operation *op) {}
625
626 /// This method is a utility wrapper around an in-place modification of an
627 /// operation. It wraps calls to `startOpModification` and
628 /// `finalizeOpModification` around the given callable.
629 template <typename CallableT>
630 void modifyOpInPlace(Operation *root, CallableT &&callable) {
631 startOpModification(op: root);
632 callable();
633 finalizeOpModification(op: root);
634 }
635
636 /// Find uses of `from` and replace them with `to`. Also notify the listener
637 /// about every in-place op modification (for every use that was replaced).
638 void replaceAllUsesWith(Value from, Value to) {
639 for (OpOperand &operand : llvm::make_early_inc_range(Range: from.getUses())) {
640 Operation *op = operand.getOwner();
641 modifyOpInPlace(root: op, callable: [&]() { operand.set(to); });
642 }
643 }
644 void replaceAllUsesWith(Block *from, Block *to) {
645 for (BlockOperand &operand : llvm::make_early_inc_range(Range: from->getUses())) {
646 Operation *op = operand.getOwner();
647 modifyOpInPlace(root: op, callable: [&]() { operand.set(to); });
648 }
649 }
650 void replaceAllUsesWith(ValueRange from, ValueRange to) {
651 assert(from.size() == to.size() && "incorrect number of replacements");
652 for (auto it : llvm::zip(t&: from, u&: to))
653 replaceAllUsesWith(from: std::get<0>(t&: it), to: std::get<1>(t&: it));
654 }
655
656 /// Find uses of `from` and replace them with `to`. Also notify the listener
657 /// about every in-place op modification (for every use that was replaced)
658 /// and that the `from` operation is about to be replaced.
659 ///
660 /// Note: This function cannot be called `replaceAllUsesWith` because the
661 /// overload resolution, when called with an op that can be implicitly
662 /// converted to a Value, would be ambiguous.
663 void replaceAllOpUsesWith(Operation *from, ValueRange to);
664 void replaceAllOpUsesWith(Operation *from, Operation *to);
665
666 /// Find uses of `from` and replace them with `to` if the `functor` returns
667 /// true. Also notify the listener about every in-place op modification (for
668 /// every use that was replaced). The optional `allUsesReplaced` flag is set
669 /// to "true" if all uses were replaced.
670 void replaceUsesWithIf(Value from, Value to,
671 function_ref<bool(OpOperand &)> functor,
672 bool *allUsesReplaced = nullptr);
673 void replaceUsesWithIf(ValueRange from, ValueRange to,
674 function_ref<bool(OpOperand &)> functor,
675 bool *allUsesReplaced = nullptr);
676 // Note: This function cannot be called `replaceOpUsesWithIf` because the
677 // overload resolution, when called with an op that can be implicitly
678 // converted to a Value, would be ambiguous.
679 void replaceOpUsesWithIf(Operation *from, ValueRange to,
680 function_ref<bool(OpOperand &)> functor,
681 bool *allUsesReplaced = nullptr) {
682 replaceUsesWithIf(from: from->getResults(), to, functor, allUsesReplaced);
683 }
684
685 /// Find uses of `from` within `block` and replace them with `to`. Also notify
686 /// the listener about every in-place op modification (for every use that was
687 /// replaced). The optional `allUsesReplaced` flag is set to "true" if all
688 /// uses were replaced.
689 void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues,
690 Block *block, bool *allUsesReplaced = nullptr) {
691 replaceOpUsesWithIf(
692 from: op, to: newValues,
693 functor: [block](OpOperand &use) {
694 return block->getParentOp()->isProperAncestor(other: use.getOwner());
695 },
696 allUsesReplaced);
697 }
698
699 /// Find uses of `from` and replace them with `to` except if the user is
700 /// `exceptedUser`. Also notify the listener about every in-place op
701 /// modification (for every use that was replaced).
702 void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
703 return replaceUsesWithIf(from, to, functor: [&](OpOperand &use) {
704 Operation *user = use.getOwner();
705 return user != exceptedUser;
706 });
707 }
708 void replaceAllUsesExcept(Value from, Value to,
709 const SmallPtrSetImpl<Operation *> &preservedUsers);
710
711 /// Used to notify the listener that the IR failed to be rewritten because of
712 /// a match failure, and provide a callback to populate a diagnostic with the
713 /// reason why the failure occurred. This method allows for derived rewriters
714 /// to optionally hook into the reason why a rewrite failed, and display it to
715 /// users.
716 template <typename CallbackT>
717 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
718 notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
719 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
720 rewriteListener->notifyMatchFailure(
721 loc, reasonCallback: function_ref<void(Diagnostic &)>(reasonCallback));
722 return failure();
723 }
724 template <typename CallbackT>
725 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
726 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
727 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
728 rewriteListener->notifyMatchFailure(
729 loc: op->getLoc(), reasonCallback: function_ref<void(Diagnostic &)>(reasonCallback));
730 return failure();
731 }
732 template <typename ArgT>
733 LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
734 return notifyMatchFailure(std::forward<ArgT>(arg),
735 [&](Diagnostic &diag) { diag << msg; });
736 }
737 template <typename ArgT>
738 LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
739 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
740 }
741
742protected:
743 /// Initialize the builder.
744 explicit RewriterBase(MLIRContext *ctx,
745 OpBuilder::Listener *listener = nullptr)
746 : OpBuilder(ctx, listener) {}
747 explicit RewriterBase(const OpBuilder &otherBuilder)
748 : OpBuilder(otherBuilder) {}
749 explicit RewriterBase(Operation *op, OpBuilder::Listener *listener = nullptr)
750 : OpBuilder(op, listener) {}
751 virtual ~RewriterBase();
752
753private:
754 void operator=(const RewriterBase &) = delete;
755 RewriterBase(const RewriterBase &) = delete;
756};
757
758//===----------------------------------------------------------------------===//
759// IRRewriter
760//===----------------------------------------------------------------------===//
761
762/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
763/// providing a way to keep track of the mutations made to the IR. This class
764/// should only be used in situations where another `RewriterBase` instance,
765/// such as a `PatternRewriter`, is not available.
766class IRRewriter : public RewriterBase {
767public:
768 explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
769 : RewriterBase(ctx, listener) {}
770 explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
771 explicit IRRewriter(Operation *op, OpBuilder::Listener *listener = nullptr)
772 : RewriterBase(op, listener) {}
773};
774
775//===----------------------------------------------------------------------===//
776// PatternRewriter
777//===----------------------------------------------------------------------===//
778
779/// A special type of `RewriterBase` that coordinates the application of a
780/// rewrite pattern on the current IR being matched, providing a way to keep
781/// track of any mutations made. This class should be used to perform all
782/// necessary IR mutations within a rewrite pattern, as the pattern driver may
783/// be tracking various state that would be invalidated when a mutation takes
784/// place.
785class PatternRewriter : public RewriterBase {
786public:
787 using RewriterBase::RewriterBase;
788
789 /// A hook used to indicate if the pattern rewriter can recover from failure
790 /// during the rewrite stage of a pattern. For example, if the pattern
791 /// rewriter supports rollback, it may progress smoothly even if IR was
792 /// changed during the rewrite.
793 virtual bool canRecoverFromRewriteFailure() const { return false; }
794};
795
796} // namespace mlir
797
798// Optionally expose PDL pattern matching methods.
799#include "PDLPatternMatch.h.inc"
800
801namespace mlir {
802
803//===----------------------------------------------------------------------===//
804// RewritePatternSet
805//===----------------------------------------------------------------------===//
806
807class RewritePatternSet {
808 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
809
810public:
811 RewritePatternSet(MLIRContext *context) : context(context) {}
812
813 /// Construct a RewritePatternSet populated with the given pattern.
814 RewritePatternSet(MLIRContext *context,
815 std::unique_ptr<RewritePattern> pattern)
816 : context(context) {
817 nativePatterns.emplace_back(args: std::move(pattern));
818 }
819 RewritePatternSet(PDLPatternModule &&pattern)
820 : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
821
822 MLIRContext *getContext() const { return context; }
823
824 /// Return the native patterns held in this list.
825 NativePatternListT &getNativePatterns() { return nativePatterns; }
826
827 /// Return the PDL patterns held in this list.
828 PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
829
830 /// Clear out all of the held patterns in this list.
831 void clear() {
832 nativePatterns.clear();
833 pdlPatterns.clear();
834 }
835
836 //===--------------------------------------------------------------------===//
837 // 'add' methods for adding patterns to the set.
838 //===--------------------------------------------------------------------===//
839
840 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
841 /// the given arguments. Return a reference to `this` for chaining insertions.
842 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
843 template <typename... Ts, typename ConstructorArg,
844 typename... ConstructorArgs,
845 typename = std::enable_if_t<sizeof...(Ts) != 0>>
846 RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
847 // The following expands a call to emplace_back for each of the pattern
848 // types 'Ts'.
849 (addImpl<Ts>(/*debugLabels=*/std::nullopt,
850 std::forward<ConstructorArg>(arg),
851 std::forward<ConstructorArgs>(args)...),
852 ...);
853 return *this;
854 }
855 /// An overload of the above `add` method that allows for attaching a set
856 /// of debug labels to the attached patterns. This is useful for labeling
857 /// groups of patterns that may be shared between multiple different
858 /// passes/users.
859 template <typename... Ts, typename ConstructorArg,
860 typename... ConstructorArgs,
861 typename = std::enable_if_t<sizeof...(Ts) != 0>>
862 RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
863 ConstructorArg &&arg,
864 ConstructorArgs &&...args) {
865 // The following expands a call to emplace_back for each of the pattern
866 // types 'Ts'.
867 (addImpl<Ts>(debugLabels, arg, args...), ...);
868 return *this;
869 }
870
871 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
872 /// `this` for chaining insertions.
873 template <typename... Ts>
874 RewritePatternSet &add() {
875 (addImpl<Ts>(), ...);
876 return *this;
877 }
878
879 /// Add the given native pattern to the pattern list. Return a reference to
880 /// `this` for chaining insertions.
881 RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
882 nativePatterns.emplace_back(args: std::move(pattern));
883 return *this;
884 }
885
886 /// Add the given PDL pattern to the pattern list. Return a reference to
887 /// `this` for chaining insertions.
888 RewritePatternSet &add(PDLPatternModule &&pattern) {
889 pdlPatterns.mergeIn(std::move(pattern));
890 return *this;
891 }
892
893 // Add a matchAndRewrite style pattern represented as a C function pointer.
894 template <typename OpType>
895 RewritePatternSet &
896 add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
897 PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
898 struct FnPattern final : public OpRewritePattern<OpType> {
899 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
900 MLIRContext *context, PatternBenefit benefit,
901 ArrayRef<StringRef> generatedNames)
902 : OpRewritePattern<OpType>(context, benefit, generatedNames),
903 implFn(implFn) {}
904
905 LogicalResult matchAndRewrite(OpType op,
906 PatternRewriter &rewriter) const override {
907 return implFn(op, rewriter);
908 }
909
910 private:
911 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
912 };
913 add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
914 generatedNames));
915 return *this;
916 }
917
918 //===--------------------------------------------------------------------===//
919 // Pattern Insertion
920 //===--------------------------------------------------------------------===//
921
922 // TODO: These are soft deprecated in favor of the 'add' methods above.
923
924 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
925 /// the given arguments. Return a reference to `this` for chaining insertions.
926 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
927 template <typename... Ts, typename ConstructorArg,
928 typename... ConstructorArgs,
929 typename = std::enable_if_t<sizeof...(Ts) != 0>>
930 RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
931 // The following expands a call to emplace_back for each of the pattern
932 // types 'Ts'.
933 (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...);
934 return *this;
935 }
936
937 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
938 /// `this` for chaining insertions.
939 template <typename... Ts>
940 RewritePatternSet &insert() {
941 (addImpl<Ts>(), ...);
942 return *this;
943 }
944
945 /// Add the given native pattern to the pattern list. Return a reference to
946 /// `this` for chaining insertions.
947 RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
948 nativePatterns.emplace_back(args: std::move(pattern));
949 return *this;
950 }
951
952 /// Add the given PDL pattern to the pattern list. Return a reference to
953 /// `this` for chaining insertions.
954 RewritePatternSet &insert(PDLPatternModule &&pattern) {
955 pdlPatterns.mergeIn(std::move(pattern));
956 return *this;
957 }
958
959 // Add a matchAndRewrite style pattern represented as a C function pointer.
960 template <typename OpType>
961 RewritePatternSet &
962 insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
963 struct FnPattern final : public OpRewritePattern<OpType> {
964 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
965 MLIRContext *context)
966 : OpRewritePattern<OpType>(context), implFn(implFn) {
967 this->setDebugName(llvm::getTypeName<FnPattern>());
968 }
969
970 LogicalResult matchAndRewrite(OpType op,
971 PatternRewriter &rewriter) const override {
972 return implFn(op, rewriter);
973 }
974
975 private:
976 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
977 };
978 add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
979 return *this;
980 }
981
982private:
983 /// Add an instance of the pattern type 'T'. Return a reference to `this` for
984 /// chaining insertions.
985 template <typename T, typename... Args>
986 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
987 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
988 std::unique_ptr<T> pattern =
989 RewritePattern::create<T>(std::forward<Args>(args)...);
990 pattern->addDebugLabels(debugLabels);
991 nativePatterns.emplace_back(std::move(pattern));
992 }
993
994 template <typename T, typename... Args>
995 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
996 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
997 // TODO: Add the provided labels to the PDL pattern when PDL supports
998 // labels.
999 pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1000 }
1001
1002 MLIRContext *const context;
1003 NativePatternListT nativePatterns;
1004
1005 // Patterns expressed with PDL. This will compile to a stub class when PDL is
1006 // not enabled.
1007 PDLPatternModule pdlPatterns;
1008};
1009
1010} // namespace mlir
1011
1012#endif // MLIR_IR_PATTERNMATCH_H
1013

// Source: mlir/include/mlir/IR/PatternMatch.h