| 1 | //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_IR_PATTERNMATCH_H |
| 10 | #define MLIR_IR_PATTERNMATCH_H |
| 11 | |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "mlir/IR/BuiltinOps.h" |
| 14 | #include "llvm/ADT/FunctionExtras.h" |
| 15 | #include "llvm/Support/TypeName.h" |
| 16 | #include <optional> |
| 17 | |
| 18 | using llvm::SmallPtrSetImpl; |
| 19 | namespace mlir { |
| 20 | |
| 21 | class PatternRewriter; |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // PatternBenefit class |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | |
/// A unitless measure of how profitable applying a pattern is, ranging from
/// 0 (very little benefit) to 65K. The most common unit is the "number of
/// operations matched" by the pattern.
///
/// The largest representable value (65535) is reserved as a sentinel meaning
/// "this pattern can never match"; a default-constructed PatternBenefit holds
/// that sentinel.
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  PatternBenefit() = default;
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used by patterns that can never match.
  static PatternBenefit impossibleToMatch() {
    PatternBenefit sentinel;
    return sentinel;
  }

  /// Return true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const {
    return representation == ImpossibleToMatchSentinel;
  }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Benefits are ordered by their stored value. Because the sentinel is the
  /// largest `unsigned short`, it compares greater than or equal to every
  /// other benefit.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const {
    return representation != rhs.representation;
  }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const {
    return representation > rhs.representation;
  }
  bool operator<=(const PatternBenefit &rhs) const {
    return representation <= rhs.representation;
  }
  bool operator>=(const PatternBenefit &rhs) const {
    return representation >= rhs.representation;
  }

private:
  /// The stored benefit; starts out as the "impossible to match" sentinel.
  unsigned short representation = ImpossibleToMatchSentinel;
};
| 64 | |
| 65 | //===----------------------------------------------------------------------===// |
| 66 | // Pattern |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | |
| 69 | /// This class contains all of the data related to a pattern, but does not |
| 70 | /// contain any methods or logic for the actual matching. This class is solely |
| 71 | /// used to interface with the metadata of a pattern, such as the benefit or |
| 72 | /// root operation. |
| 73 | class Pattern { |
| 74 | /// This enum represents the kind of value used to select the root operations |
| 75 | /// that match this pattern. |
| 76 | enum class RootKind { |
| 77 | /// The pattern root matches "any" operation. |
| 78 | Any, |
| 79 | /// The pattern root is matched using a concrete operation name. |
| 80 | OperationName, |
| 81 | /// The pattern root is matched using an interface ID. |
| 82 | InterfaceID, |
| 83 | /// The patter root is matched using a trait ID. |
| 84 | TraitID |
| 85 | }; |
| 86 | |
| 87 | public: |
| 88 | /// Return a list of operations that may be generated when rewriting an |
| 89 | /// operation instance with this pattern. |
| 90 | ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; } |
| 91 | |
| 92 | /// Return the root node that this pattern matches. Patterns that can match |
| 93 | /// multiple root types return std::nullopt. |
| 94 | std::optional<OperationName> getRootKind() const { |
| 95 | if (rootKind == RootKind::OperationName) |
| 96 | return OperationName::getFromOpaquePointer(pointer: rootValue); |
| 97 | return std::nullopt; |
| 98 | } |
| 99 | |
| 100 | /// Return the interface ID used to match the root operation of this pattern. |
| 101 | /// If the pattern does not use an interface ID for deciding the root match, |
| 102 | /// this returns std::nullopt. |
| 103 | std::optional<TypeID> getRootInterfaceID() const { |
| 104 | if (rootKind == RootKind::InterfaceID) |
| 105 | return TypeID::getFromOpaquePointer(pointer: rootValue); |
| 106 | return std::nullopt; |
| 107 | } |
| 108 | |
| 109 | /// Return the trait ID used to match the root operation of this pattern. |
| 110 | /// If the pattern does not use a trait ID for deciding the root match, this |
| 111 | /// returns std::nullopt. |
| 112 | std::optional<TypeID> getRootTraitID() const { |
| 113 | if (rootKind == RootKind::TraitID) |
| 114 | return TypeID::getFromOpaquePointer(pointer: rootValue); |
| 115 | return std::nullopt; |
| 116 | } |
| 117 | |
| 118 | /// Return the benefit (the inverse of "cost") of matching this pattern. The |
| 119 | /// benefit of a Pattern is always static - rewrites that may have dynamic |
| 120 | /// benefit can be instantiated multiple times (different Pattern instances) |
| 121 | /// for each benefit that they may return, and be guarded by different match |
| 122 | /// condition predicates. |
| 123 | PatternBenefit getBenefit() const { return benefit; } |
| 124 | |
| 125 | /// Returns true if this pattern is known to result in recursive application, |
| 126 | /// i.e. this pattern may generate IR that also matches this pattern, but is |
| 127 | /// known to bound the recursion. This signals to a rewrite driver that it is |
| 128 | /// safe to apply this pattern recursively to generated IR. |
| 129 | bool hasBoundedRewriteRecursion() const { |
| 130 | return contextAndHasBoundedRecursion.getInt(); |
| 131 | } |
| 132 | |
| 133 | /// Return the MLIRContext used to create this pattern. |
| 134 | MLIRContext *getContext() const { |
| 135 | return contextAndHasBoundedRecursion.getPointer(); |
| 136 | } |
| 137 | |
| 138 | /// Return a readable name for this pattern. This name should only be used for |
| 139 | /// debugging purposes, and may be empty. |
| 140 | StringRef getDebugName() const { return debugName; } |
| 141 | |
| 142 | /// Set the human readable debug name used for this pattern. This name will |
| 143 | /// only be used for debugging purposes. |
| 144 | void setDebugName(StringRef name) { debugName = name; } |
| 145 | |
| 146 | /// Return the set of debug labels attached to this pattern. |
| 147 | ArrayRef<StringRef> getDebugLabels() const { return debugLabels; } |
| 148 | |
| 149 | /// Add the provided debug labels to this pattern. |
| 150 | void addDebugLabels(ArrayRef<StringRef> labels) { |
| 151 | debugLabels.append(in_start: labels.begin(), in_end: labels.end()); |
| 152 | } |
| 153 | void addDebugLabels(StringRef label) { debugLabels.push_back(Elt: label); } |
| 154 | |
| 155 | protected: |
| 156 | /// This class acts as a special tag that makes the desire to match "any" |
| 157 | /// operation type explicit. This helps to avoid unnecessary usages of this |
| 158 | /// feature, and ensures that the user is making a conscious decision. |
| 159 | struct MatchAnyOpTypeTag {}; |
| 160 | /// This class acts as a special tag that makes the desire to match any |
| 161 | /// operation that implements a given interface explicit. This helps to avoid |
| 162 | /// unnecessary usages of this feature, and ensures that the user is making a |
| 163 | /// conscious decision. |
| 164 | struct MatchInterfaceOpTypeTag {}; |
| 165 | /// This class acts as a special tag that makes the desire to match any |
| 166 | /// operation that implements a given trait explicit. This helps to avoid |
| 167 | /// unnecessary usages of this feature, and ensures that the user is making a |
| 168 | /// conscious decision. |
| 169 | struct MatchTraitOpTypeTag {}; |
| 170 | |
| 171 | /// Construct a pattern with a certain benefit that matches the operation |
| 172 | /// with the given root name. |
| 173 | Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, |
| 174 | ArrayRef<StringRef> generatedNames = {}); |
| 175 | /// Construct a pattern that may match any operation type. `generatedNames` |
| 176 | /// contains the names of operations that may be generated during a successful |
| 177 | /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" |
| 178 | /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should |
| 179 | /// always be supplied here. |
| 180 | Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, |
| 181 | ArrayRef<StringRef> generatedNames = {}); |
| 182 | /// Construct a pattern that may match any operation that implements the |
| 183 | /// interface defined by the provided `interfaceID`. `generatedNames` contains |
| 184 | /// the names of operations that may be generated during a successful rewrite. |
| 185 | /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match |
| 186 | /// interface" behavior is what the user actually desired, |
| 187 | /// `MatchInterfaceOpTypeTag()` should always be supplied here. |
| 188 | Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, |
| 189 | PatternBenefit benefit, MLIRContext *context, |
| 190 | ArrayRef<StringRef> generatedNames = {}); |
| 191 | /// Construct a pattern that may match any operation that implements the |
| 192 | /// trait defined by the provided `traitID`. `generatedNames` contains the |
| 193 | /// names of operations that may be generated during a successful rewrite. |
| 194 | /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait" |
| 195 | /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should |
| 196 | /// always be supplied here. |
| 197 | Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, |
| 198 | MLIRContext *context, ArrayRef<StringRef> generatedNames = {}); |
| 199 | |
| 200 | /// Set the flag detailing if this pattern has bounded rewrite recursion or |
| 201 | /// not. |
| 202 | void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) { |
| 203 | contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg); |
| 204 | } |
| 205 | |
| 206 | private: |
| 207 | Pattern(const void *rootValue, RootKind rootKind, |
| 208 | ArrayRef<StringRef> generatedNames, PatternBenefit benefit, |
| 209 | MLIRContext *context); |
| 210 | |
| 211 | /// The value used to match the root operation of the pattern. |
| 212 | const void *rootValue; |
| 213 | RootKind rootKind; |
| 214 | |
| 215 | /// The expected benefit of matching this pattern. |
| 216 | const PatternBenefit benefit; |
| 217 | |
| 218 | /// The context this pattern was created from, and a boolean flag indicating |
| 219 | /// whether this pattern has bounded recursion or not. |
| 220 | llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion; |
| 221 | |
| 222 | /// A list of the potential operations that may be generated when rewriting |
| 223 | /// an op with this pattern. |
| 224 | SmallVector<OperationName, 2> generatedOps; |
| 225 | |
| 226 | /// A readable name for this pattern. May be empty. |
| 227 | StringRef debugName; |
| 228 | |
| 229 | /// The set of debug labels attached to this pattern. |
| 230 | SmallVector<StringRef, 0> debugLabels; |
| 231 | }; |
| 232 | |
| 233 | //===----------------------------------------------------------------------===// |
| 234 | // RewritePattern |
| 235 | //===----------------------------------------------------------------------===// |
| 236 | |
| 237 | /// RewritePattern is the common base class for all DAG to DAG replacements. |
| 238 | class RewritePattern : public Pattern { |
| 239 | public: |
| 240 | virtual ~RewritePattern() = default; |
| 241 | |
| 242 | /// Attempt to match against code rooted at the specified operation, |
| 243 | /// which is the same operation code as getRootKind(). If successful, perform |
| 244 | /// the rewrite. |
| 245 | /// |
| 246 | /// Note: Implementations must modify the IR if and only if the function |
| 247 | /// returns "success". |
| 248 | virtual LogicalResult matchAndRewrite(Operation *op, |
| 249 | PatternRewriter &rewriter) const = 0; |
| 250 | |
| 251 | /// This method provides a convenient interface for creating and initializing |
| 252 | /// derived rewrite patterns of the given type `T`. |
| 253 | template <typename T, typename... Args> |
| 254 | static std::unique_ptr<T> create(Args &&...args) { |
| 255 | std::unique_ptr<T> pattern = |
| 256 | std::make_unique<T>(std::forward<Args>(args)...); |
| 257 | initializePattern<T>(*pattern); |
| 258 | |
| 259 | // Set a default debug name if one wasn't provided. |
| 260 | if (pattern->getDebugName().empty()) |
| 261 | pattern->setDebugName(llvm::getTypeName<T>()); |
| 262 | return pattern; |
| 263 | } |
| 264 | |
| 265 | protected: |
| 266 | /// Inherit the base constructors from `Pattern`. |
| 267 | using Pattern::Pattern; |
| 268 | |
| 269 | private: |
| 270 | /// Trait to check if T provides a `initialize` method. |
| 271 | template <typename T, typename... Args> |
| 272 | using has_initialize = decltype(std::declval<T>().initialize()); |
| 273 | template <typename T> |
| 274 | using detect_has_initialize = llvm::is_detected<has_initialize, T>; |
| 275 | |
| 276 | /// Initialize the derived pattern by calling its `initialize` method if |
| 277 | /// available. |
| 278 | template <typename T> |
| 279 | static void initializePattern(T &pattern) { |
| 280 | if constexpr (detect_has_initialize<T>::value) |
| 281 | pattern.initialize(); |
| 282 | } |
| 283 | |
| 284 | /// An anchor for the virtual table. |
| 285 | virtual void anchor(); |
| 286 | }; |
| 287 | |
| 288 | namespace detail { |
| 289 | /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that |
| 290 | /// allows for matching and rewriting against an instance of a derived operation |
| 291 | /// class or Interface. |
| 292 | template <typename SourceOp> |
| 293 | struct OpOrInterfaceRewritePatternBase : public RewritePattern { |
| 294 | using RewritePattern::RewritePattern; |
| 295 | |
| 296 | /// Wrapper around the RewritePattern method that passes the derived op type. |
| 297 | LogicalResult matchAndRewrite(Operation *op, |
| 298 | PatternRewriter &rewriter) const final { |
| 299 | return matchAndRewrite(cast<SourceOp>(op), rewriter); |
| 300 | } |
| 301 | |
| 302 | /// Method that operates on the SourceOp type. Must be overridden by the |
| 303 | /// derived pattern class. |
| 304 | virtual LogicalResult matchAndRewrite(SourceOp op, |
| 305 | PatternRewriter &rewriter) const = 0; |
| 306 | }; |
| 307 | } // namespace detail |
| 308 | |
| 309 | /// OpRewritePattern is a wrapper around RewritePattern that allows for |
| 310 | /// matching and rewriting against an instance of a derived operation class as |
| 311 | /// opposed to a raw Operation. |
| 312 | template <typename SourceOp> |
| 313 | struct OpRewritePattern |
| 314 | : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { |
| 315 | |
| 316 | /// Patterns must specify the root operation name they match against, and can |
| 317 | /// also specify the benefit of the pattern matching and a list of generated |
| 318 | /// ops. |
| 319 | OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1, |
| 320 | ArrayRef<StringRef> generatedNames = {}) |
| 321 | : detail::OpOrInterfaceRewritePatternBase<SourceOp>( |
| 322 | SourceOp::getOperationName(), benefit, context, generatedNames) {} |
| 323 | }; |
| 324 | |
| 325 | /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for |
| 326 | /// matching and rewriting against an instance of an operation interface instead |
| 327 | /// of a raw Operation. |
| 328 | template <typename SourceOp> |
| 329 | struct OpInterfaceRewritePattern |
| 330 | : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { |
| 331 | |
| 332 | OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) |
| 333 | : detail::OpOrInterfaceRewritePatternBase<SourceOp>( |
| 334 | Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), |
| 335 | benefit, context) {} |
| 336 | }; |
| 337 | |
| 338 | /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for |
| 339 | /// matching and rewriting against instances of an operation that possess a |
| 340 | /// given trait. |
| 341 | template <template <typename> class TraitType> |
| 342 | class OpTraitRewritePattern : public RewritePattern { |
| 343 | public: |
| 344 | OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) |
| 345 | : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(), |
| 346 | benefit, context) {} |
| 347 | }; |
| 348 | |
| 349 | //===----------------------------------------------------------------------===// |
| 350 | // RewriterBase |
| 351 | //===----------------------------------------------------------------------===// |
| 352 | |
| 353 | /// This class coordinates the application of a rewrite on a set of IR, |
| 354 | /// providing a way for clients to track mutations and create new operations. |
| 355 | /// This class serves as a common API for IR mutation between pattern rewrites |
| 356 | /// and non-pattern rewrites, and facilitates the development of shared |
| 357 | /// IR transformation utilities. |
| 358 | class RewriterBase : public OpBuilder { |
| 359 | public: |
| 360 | struct Listener : public OpBuilder::Listener { |
| 361 | Listener() |
| 362 | : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {} |
| 363 | |
| 364 | /// Notify the listener that the specified block is about to be erased. |
| 365 | /// At this point, the block has zero uses. |
| 366 | virtual void notifyBlockErased(Block *block) {} |
| 367 | |
| 368 | /// Notify the listener that the specified operation was modified in-place. |
| 369 | virtual void notifyOperationModified(Operation *op) {} |
| 370 | |
| 371 | /// Notify the listener that all uses of the specified operation's results |
| 372 | /// are about to be replaced with the results of another operation. This is |
| 373 | /// called before the uses of the old operation have been changed. |
| 374 | /// |
| 375 | /// By default, this function calls the "operation replaced with values" |
| 376 | /// notification. |
| 377 | virtual void notifyOperationReplaced(Operation *op, |
| 378 | Operation *replacement) { |
| 379 | notifyOperationReplaced(op, replacement: replacement->getResults()); |
| 380 | } |
| 381 | |
| 382 | /// Notify the listener that all uses of the specified operation's results |
| 383 | /// are about to be replaced with the a range of values, potentially |
| 384 | /// produced by other operations. This is called before the uses of the |
| 385 | /// operation have been changed. |
| 386 | virtual void notifyOperationReplaced(Operation *op, |
| 387 | ValueRange replacement) {} |
| 388 | |
| 389 | /// Notify the listener that the specified operation is about to be erased. |
| 390 | /// At this point, the operation has zero uses. |
| 391 | /// |
| 392 | /// Note: This notification is not triggered when unlinking an operation. |
| 393 | virtual void notifyOperationErased(Operation *op) {} |
| 394 | |
| 395 | /// Notify the listener that the specified pattern is about to be applied |
| 396 | /// at the specified root operation. |
| 397 | virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {} |
| 398 | |
| 399 | /// Notify the listener that a pattern application finished with the |
| 400 | /// specified status. "success" indicates that the pattern was applied |
| 401 | /// successfully. "failure" indicates that the pattern could not be |
| 402 | /// applied. The pattern may have communicated the reason for the failure |
| 403 | /// with `notifyMatchFailure`. |
| 404 | virtual void notifyPatternEnd(const Pattern &pattern, |
| 405 | LogicalResult status) {} |
| 406 | |
| 407 | /// Notify the listener that the pattern failed to match, and provide a |
| 408 | /// callback to populate a diagnostic with the reason why the failure |
| 409 | /// occurred. This method allows for derived listeners to optionally hook |
| 410 | /// into the reason why a rewrite failed, and display it to users. |
| 411 | virtual void |
| 412 | notifyMatchFailure(Location loc, |
| 413 | function_ref<void(Diagnostic &)> reasonCallback) {} |
| 414 | |
| 415 | static bool classof(const OpBuilder::Listener *base); |
| 416 | }; |
| 417 | |
| 418 | /// A listener that forwards all notifications to another listener. This |
| 419 | /// struct can be used as a base to create listener chains, so that multiple |
| 420 | /// listeners can be notified of IR changes. |
| 421 | struct ForwardingListener : public RewriterBase::Listener { |
| 422 | ForwardingListener(OpBuilder::Listener *listener) |
| 423 | : listener(listener), |
| 424 | rewriteListener( |
| 425 | dyn_cast_if_present<RewriterBase::Listener>(Val: listener)) {} |
| 426 | |
| 427 | void notifyOperationInserted(Operation *op, InsertPoint previous) override { |
| 428 | if (listener) |
| 429 | listener->notifyOperationInserted(op, previous); |
| 430 | } |
| 431 | void notifyBlockInserted(Block *block, Region *previous, |
| 432 | Region::iterator previousIt) override { |
| 433 | if (listener) |
| 434 | listener->notifyBlockInserted(block, previous, previousIt); |
| 435 | } |
| 436 | void notifyBlockErased(Block *block) override { |
| 437 | if (rewriteListener) |
| 438 | rewriteListener->notifyBlockErased(block); |
| 439 | } |
| 440 | void notifyOperationModified(Operation *op) override { |
| 441 | if (rewriteListener) |
| 442 | rewriteListener->notifyOperationModified(op); |
| 443 | } |
| 444 | void notifyOperationReplaced(Operation *op, Operation *newOp) override { |
| 445 | if (rewriteListener) |
| 446 | rewriteListener->notifyOperationReplaced(op, replacement: newOp); |
| 447 | } |
| 448 | void notifyOperationReplaced(Operation *op, |
| 449 | ValueRange replacement) override { |
| 450 | if (rewriteListener) |
| 451 | rewriteListener->notifyOperationReplaced(op, replacement); |
| 452 | } |
| 453 | void notifyOperationErased(Operation *op) override { |
| 454 | if (rewriteListener) |
| 455 | rewriteListener->notifyOperationErased(op); |
| 456 | } |
| 457 | void notifyPatternBegin(const Pattern &pattern, Operation *op) override { |
| 458 | if (rewriteListener) |
| 459 | rewriteListener->notifyPatternBegin(pattern, op); |
| 460 | } |
| 461 | void notifyPatternEnd(const Pattern &pattern, |
| 462 | LogicalResult status) override { |
| 463 | if (rewriteListener) |
| 464 | rewriteListener->notifyPatternEnd(pattern, status); |
| 465 | } |
| 466 | void notifyMatchFailure( |
| 467 | Location loc, |
| 468 | function_ref<void(Diagnostic &)> reasonCallback) override { |
| 469 | if (rewriteListener) |
| 470 | rewriteListener->notifyMatchFailure(loc, reasonCallback); |
| 471 | } |
| 472 | |
| 473 | private: |
| 474 | OpBuilder::Listener *listener; |
| 475 | RewriterBase::Listener *rewriteListener; |
| 476 | }; |
| 477 | |
| 478 | /// Move the blocks that belong to "region" before the given position in |
| 479 | /// another region "parent". The two regions must be different. The caller |
| 480 | /// is responsible for creating or updating the operation transferring flow |
| 481 | /// of control to the region and passing it the correct block arguments. |
| 482 | void inlineRegionBefore(Region ®ion, Region &parent, |
| 483 | Region::iterator before); |
| 484 | void inlineRegionBefore(Region ®ion, Block *before); |
| 485 | |
| 486 | /// Replace the results of the given (original) operation with the specified |
| 487 | /// list of values (replacements). The result types of the given op and the |
| 488 | /// replacements must match. The original op is erased. |
| 489 | virtual void replaceOp(Operation *op, ValueRange newValues); |
| 490 | |
| 491 | /// Replace the results of the given (original) operation with the specified |
| 492 | /// new op (replacement). The result types of the two ops must match. The |
| 493 | /// original op is erased. |
| 494 | virtual void replaceOp(Operation *op, Operation *newOp); |
| 495 | |
| 496 | /// Replace the results of the given (original) op with a new op that is |
| 497 | /// created without verification (replacement). The result values of the two |
| 498 | /// ops must match. The original op is erased. |
| 499 | template <typename OpTy, typename... Args> |
| 500 | OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { |
| 501 | auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); |
| 502 | replaceOp(op, newOp.getOperation()); |
| 503 | return newOp; |
| 504 | } |
| 505 | |
| 506 | /// This method erases an operation that is known to have no uses. |
| 507 | virtual void eraseOp(Operation *op); |
| 508 | |
| 509 | /// This method erases all operations in a block. |
| 510 | virtual void eraseBlock(Block *block); |
| 511 | |
| 512 | /// Inline the operations of block 'source' into block 'dest' before the given |
| 513 | /// position. The source block will be deleted and must have no uses. |
| 514 | /// 'argValues' is used to replace the block arguments of 'source'. |
| 515 | /// |
| 516 | /// If the source block is inserted at the end of the dest block, the dest |
| 517 | /// block must have no successors. Similarly, if the source block is inserted |
| 518 | /// somewhere in the middle (or beginning) of the dest block, the source block |
| 519 | /// must have no successors. Otherwise, the resulting IR would have |
| 520 | /// unreachable operations. |
| 521 | virtual void inlineBlockBefore(Block *source, Block *dest, |
| 522 | Block::iterator before, |
| 523 | ValueRange argValues = std::nullopt); |
| 524 | |
| 525 | /// Inline the operations of block 'source' before the operation 'op'. The |
| 526 | /// source block will be deleted and must have no uses. 'argValues' is used to |
| 527 | /// replace the block arguments of 'source' |
| 528 | /// |
| 529 | /// The source block must have no successors. Otherwise, the resulting IR |
| 530 | /// would have unreachable operations. |
| 531 | void inlineBlockBefore(Block *source, Operation *op, |
| 532 | ValueRange argValues = std::nullopt); |
| 533 | |
| 534 | /// Inline the operations of block 'source' into the end of block 'dest'. The |
| 535 | /// source block will be deleted and must have no uses. 'argValues' is used to |
| 536 | /// replace the block arguments of 'source' |
| 537 | /// |
| 538 | /// The dest block must have no successors. Otherwise, the resulting IR would |
| 539 | /// have unreachable operation. |
| 540 | void mergeBlocks(Block *source, Block *dest, |
| 541 | ValueRange argValues = std::nullopt); |
| 542 | |
| 543 | /// Split the operations starting at "before" (inclusive) out of the given |
| 544 | /// block into a new block, and return it. |
| 545 | Block *splitBlock(Block *block, Block::iterator before); |
| 546 | |
| 547 | /// Unlink this operation from its current block and insert it right before |
| 548 | /// `existingOp` which may be in the same or another block in the same |
| 549 | /// function. |
| 550 | void moveOpBefore(Operation *op, Operation *existingOp); |
| 551 | |
| 552 | /// Unlink this operation from its current block and insert it right before |
| 553 | /// `iterator` in the specified block. |
| 554 | void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); |
| 555 | |
| 556 | /// Unlink this operation from its current block and insert it right after |
| 557 | /// `existingOp` which may be in the same or another block in the same |
| 558 | /// function. |
| 559 | void moveOpAfter(Operation *op, Operation *existingOp); |
| 560 | |
| 561 | /// Unlink this operation from its current block and insert it right after |
| 562 | /// `iterator` in the specified block. |
| 563 | void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); |
| 564 | |
| 565 | /// Unlink this block and insert it right before `existingBlock`. |
| 566 | void moveBlockBefore(Block *block, Block *anotherBlock); |
| 567 | |
| 568 | /// Unlink this block and insert it right before the location that the given |
| 569 | /// iterator points to in the given region. |
| 570 | void moveBlockBefore(Block *block, Region *region, Region::iterator iterator); |
| 571 | |
| 572 | /// This method is used to notify the rewriter that an in-place operation |
| 573 | /// modification is about to happen. A call to this function *must* be |
| 574 | /// followed by a call to either `finalizeOpModification` or |
| 575 | /// `cancelOpModification`. This is a minor efficiency win (it avoids creating |
| 576 | /// a new operation and removing the old one) but also often allows simpler |
| 577 | /// code in the client. |
| 578 | virtual void startOpModification(Operation *op) {} |
| 579 | |
| 580 | /// This method is used to signal the end of an in-place modification of the |
| 581 | /// given operation. This can only be called on operations that were provided |
| 582 | /// to a call to `startOpModification`. |
| 583 | virtual void finalizeOpModification(Operation *op); |
| 584 | |
| 585 | /// This method cancels a pending in-place modification. This can only be |
| 586 | /// called on operations that were provided to a call to |
| 587 | /// `startOpModification`. |
| 588 | virtual void cancelOpModification(Operation *op) {} |
| 589 | |
| 590 | /// This method is a utility wrapper around an in-place modification of an |
| 591 | /// operation. It wraps calls to `startOpModification` and |
| 592 | /// `finalizeOpModification` around the given callable. |
| 593 | template <typename CallableT> |
| 594 | void modifyOpInPlace(Operation *root, CallableT &&callable) { |
| 595 | startOpModification(op: root); |
| 596 | callable(); |
| 597 | finalizeOpModification(op: root); |
| 598 | } |
| 599 | |
| 600 | /// Find uses of `from` and replace them with `to`. Also notify the listener |
| 601 | /// about every in-place op modification (for every use that was replaced). |
| 602 | void replaceAllUsesWith(Value from, Value to) { |
| 603 | for (OpOperand &operand : llvm::make_early_inc_range(Range: from.getUses())) { |
| 604 | Operation *op = operand.getOwner(); |
| 605 | modifyOpInPlace(root: op, callable: [&]() { operand.set(to); }); |
| 606 | } |
| 607 | } |
| 608 | void replaceAllUsesWith(Block *from, Block *to) { |
| 609 | for (BlockOperand &operand : llvm::make_early_inc_range(Range: from->getUses())) { |
| 610 | Operation *op = operand.getOwner(); |
| 611 | modifyOpInPlace(root: op, callable: [&]() { operand.set(to); }); |
| 612 | } |
| 613 | } |
| 614 | void replaceAllUsesWith(ValueRange from, ValueRange to) { |
| 615 | assert(from.size() == to.size() && "incorrect number of replacements" ); |
| 616 | for (auto it : llvm::zip(t&: from, u&: to)) |
| 617 | replaceAllUsesWith(from: std::get<0>(t&: it), to: std::get<1>(t&: it)); |
| 618 | } |
| 619 | |
| 620 | /// Find uses of `from` and replace them with `to`. Also notify the listener |
| 621 | /// about every in-place op modification (for every use that was replaced) |
| 622 | /// and that the `from` operation is about to be replaced. |
| 623 | /// |
| 624 | /// Note: This function cannot be called `replaceAllUsesWith` because the |
| 625 | /// overload resolution, when called with an op that can be implicitly |
| 626 | /// converted to a Value, would be ambiguous. |
| 627 | void replaceAllOpUsesWith(Operation *from, ValueRange to); |
| 628 | void replaceAllOpUsesWith(Operation *from, Operation *to); |
| 629 | |
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
///
/// `functor` is called with each candidate use (`OpOperand &`) and decides
/// whether that particular use is rewritten. The `ValueRange` overload pairs
/// `from[i]` with `to[i]` (presumably requiring equal lengths — confirm in the
/// out-of-line definition).
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
| 640 | // Note: This function cannot be called `replaceOpUsesWithIf` because the |
| 641 | // overload resolution, when called with an op that can be implicitly |
| 642 | // converted to a Value, would be ambiguous. |
| 643 | void replaceOpUsesWithIf(Operation *from, ValueRange to, |
| 644 | function_ref<bool(OpOperand &)> functor, |
| 645 | bool *allUsesReplaced = nullptr) { |
| 646 | replaceUsesWithIf(from: from->getResults(), to, functor, allUsesReplaced); |
| 647 | } |
| 648 | |
| 649 | /// Find uses of `from` within `block` and replace them with `to`. Also notify |
| 650 | /// the listener about every in-place op modification (for every use that was |
| 651 | /// replaced). The optional `allUsesReplaced` flag is set to "true" if all |
| 652 | /// uses were replaced. |
| 653 | void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, |
| 654 | Block *block, bool *allUsesReplaced = nullptr) { |
| 655 | replaceOpUsesWithIf( |
| 656 | from: op, to: newValues, |
| 657 | functor: [block](OpOperand &use) { |
| 658 | return block->getParentOp()->isProperAncestor(other: use.getOwner()); |
| 659 | }, |
| 660 | allUsesReplaced); |
| 661 | } |
| 662 | |
| 663 | /// Find uses of `from` and replace them with `to` except if the user is |
| 664 | /// `exceptedUser`. Also notify the listener about every in-place op |
| 665 | /// modification (for every use that was replaced). |
| 666 | void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) { |
| 667 | return replaceUsesWithIf(from, to, functor: [&](OpOperand &use) { |
| 668 | Operation *user = use.getOwner(); |
| 669 | return user != exceptedUser; |
| 670 | }); |
| 671 | } |
| 672 | void replaceAllUsesExcept(Value from, Value to, |
| 673 | const SmallPtrSetImpl<Operation *> &preservedUsers); |
| 674 | |
| 675 | /// Used to notify the listener that the IR failed to be rewritten because of |
| 676 | /// a match failure, and provide a callback to populate a diagnostic with the |
| 677 | /// reason why the failure occurred. This method allows for derived rewriters |
| 678 | /// to optionally hook into the reason why a rewrite failed, and display it to |
| 679 | /// users. |
| 680 | template <typename CallbackT> |
| 681 | std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> |
| 682 | notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { |
| 683 | if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener)) |
| 684 | rewriteListener->notifyMatchFailure( |
| 685 | loc, reasonCallback: function_ref<void(Diagnostic &)>(reasonCallback)); |
| 686 | return failure(); |
| 687 | } |
| 688 | template <typename CallbackT> |
| 689 | std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> |
| 690 | notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { |
| 691 | if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener)) |
| 692 | rewriteListener->notifyMatchFailure( |
| 693 | loc: op->getLoc(), reasonCallback: function_ref<void(Diagnostic &)>(reasonCallback)); |
| 694 | return failure(); |
| 695 | } |
| 696 | template <typename ArgT> |
| 697 | LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { |
| 698 | return notifyMatchFailure(std::forward<ArgT>(arg), |
| 699 | [&](Diagnostic &diag) { diag << msg; }); |
| 700 | } |
| 701 | template <typename ArgT> |
| 702 | LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) { |
| 703 | return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg)); |
| 704 | } |
| 705 | |
| 706 | protected: |
| 707 | /// Initialize the builder. |
| 708 | explicit RewriterBase(MLIRContext *ctx, |
| 709 | OpBuilder::Listener *listener = nullptr) |
| 710 | : OpBuilder(ctx, listener) {} |
| 711 | explicit RewriterBase(const OpBuilder &otherBuilder) |
| 712 | : OpBuilder(otherBuilder) {} |
| 713 | explicit RewriterBase(Operation *op, OpBuilder::Listener *listener = nullptr) |
| 714 | : OpBuilder(op, listener) {} |
| 715 | virtual ~RewriterBase(); |
| 716 | |
| 717 | private: |
| 718 | void operator=(const RewriterBase &) = delete; |
| 719 | RewriterBase(const RewriterBase &) = delete; |
| 720 | }; |
| 721 | |
| 722 | //===----------------------------------------------------------------------===// |
| 723 | // IRRewriter |
| 724 | //===----------------------------------------------------------------------===// |
| 725 | |
| 726 | /// This class coordinates rewriting a piece of IR outside of a pattern rewrite, |
| 727 | /// providing a way to keep track of the mutations made to the IR. This class |
| 728 | /// should only be used in situations where another `RewriterBase` instance, |
| 729 | /// such as a `PatternRewriter`, is not available. |
| 730 | class IRRewriter : public RewriterBase { |
| 731 | public: |
| 732 | explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) |
| 733 | : RewriterBase(ctx, listener) {} |
| 734 | explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} |
| 735 | explicit IRRewriter(Operation *op, OpBuilder::Listener *listener = nullptr) |
| 736 | : RewriterBase(op, listener) {} |
| 737 | }; |
| 738 | |
| 739 | //===----------------------------------------------------------------------===// |
| 740 | // PatternRewriter |
| 741 | //===----------------------------------------------------------------------===// |
| 742 | |
| 743 | /// A special type of `RewriterBase` that coordinates the application of a |
| 744 | /// rewrite pattern on the current IR being matched, providing a way to keep |
| 745 | /// track of any mutations made. This class should be used to perform all |
| 746 | /// necessary IR mutations within a rewrite pattern, as the pattern driver may |
| 747 | /// be tracking various state that would be invalidated when a mutation takes |
| 748 | /// place. |
| 749 | class PatternRewriter : public RewriterBase { |
| 750 | public: |
| 751 | explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} |
| 752 | using RewriterBase::RewriterBase; |
| 753 | |
| 754 | /// A hook used to indicate if the pattern rewriter can recover from failure |
| 755 | /// during the rewrite stage of a pattern. For example, if the pattern |
| 756 | /// rewriter supports rollback, it may progress smoothly even if IR was |
| 757 | /// changed during the rewrite. |
| 758 | virtual bool canRecoverFromRewriteFailure() const { return false; } |
| 759 | }; |
| 760 | |
| 761 | } // namespace mlir |
| 762 | |
| 763 | // Optionally expose PDL pattern matching methods. |
| 764 | #include "PDLPatternMatch.h.inc" |
| 765 | |
| 766 | namespace mlir { |
| 767 | |
| 768 | //===----------------------------------------------------------------------===// |
| 769 | // RewritePatternSet |
| 770 | //===----------------------------------------------------------------------===// |
| 771 | |
| 772 | class RewritePatternSet { |
| 773 | using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; |
| 774 | |
| 775 | public: |
| 776 | RewritePatternSet(MLIRContext *context) : context(context) {} |
| 777 | |
| 778 | /// Construct a RewritePatternSet populated with the given pattern. |
| 779 | RewritePatternSet(MLIRContext *context, |
| 780 | std::unique_ptr<RewritePattern> pattern) |
| 781 | : context(context) { |
| 782 | nativePatterns.emplace_back(args: std::move(pattern)); |
| 783 | } |
| 784 | RewritePatternSet(PDLPatternModule &&pattern) |
| 785 | : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {} |
| 786 | |
| 787 | MLIRContext *getContext() const { return context; } |
| 788 | |
| 789 | /// Return the native patterns held in this list. |
| 790 | NativePatternListT &getNativePatterns() { return nativePatterns; } |
| 791 | |
| 792 | /// Return the PDL patterns held in this list. |
| 793 | PDLPatternModule &getPDLPatterns() { return pdlPatterns; } |
| 794 | |
| 795 | /// Clear out all of the held patterns in this list. |
| 796 | void clear() { |
| 797 | nativePatterns.clear(); |
| 798 | pdlPatterns.clear(); |
| 799 | } |
| 800 | |
| 801 | //===--------------------------------------------------------------------===// |
| 802 | // 'add' methods for adding patterns to the set. |
| 803 | //===--------------------------------------------------------------------===// |
| 804 | |
| 805 | /// Add an instance of each of the pattern types 'Ts' to the pattern list with |
| 806 | /// the given arguments. Return a reference to `this` for chaining insertions. |
| 807 | /// Note: ConstructorArg is necessary here to separate the two variadic lists. |
| 808 | template <typename... Ts, typename ConstructorArg, |
| 809 | typename... ConstructorArgs, |
| 810 | typename = std::enable_if_t<sizeof...(Ts) != 0>> |
| 811 | RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) { |
| 812 | // The following expands a call to emplace_back for each of the pattern |
| 813 | // types 'Ts'. |
| 814 | (addImpl<Ts>(/*debugLabels=*/std::nullopt, |
| 815 | std::forward<ConstructorArg>(arg), |
| 816 | std::forward<ConstructorArgs>(args)...), |
| 817 | ...); |
| 818 | return *this; |
| 819 | } |
| 820 | /// An overload of the above `add` method that allows for attaching a set |
| 821 | /// of debug labels to the attached patterns. This is useful for labeling |
| 822 | /// groups of patterns that may be shared between multiple different |
| 823 | /// passes/users. |
| 824 | template <typename... Ts, typename ConstructorArg, |
| 825 | typename... ConstructorArgs, |
| 826 | typename = std::enable_if_t<sizeof...(Ts) != 0>> |
| 827 | RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels, |
| 828 | ConstructorArg &&arg, |
| 829 | ConstructorArgs &&...args) { |
| 830 | // The following expands a call to emplace_back for each of the pattern |
| 831 | // types 'Ts'. |
| 832 | (addImpl<Ts>(debugLabels, arg, args...), ...); |
| 833 | return *this; |
| 834 | } |
| 835 | |
| 836 | /// Add an instance of each of the pattern types 'Ts'. Return a reference to |
| 837 | /// `this` for chaining insertions. |
| 838 | template <typename... Ts> |
| 839 | RewritePatternSet &add() { |
| 840 | (addImpl<Ts>(), ...); |
| 841 | return *this; |
| 842 | } |
| 843 | |
| 844 | /// Add the given native pattern to the pattern list. Return a reference to |
| 845 | /// `this` for chaining insertions. |
| 846 | RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) { |
| 847 | nativePatterns.emplace_back(args: std::move(pattern)); |
| 848 | return *this; |
| 849 | } |
| 850 | |
| 851 | /// Add the given PDL pattern to the pattern list. Return a reference to |
| 852 | /// `this` for chaining insertions. |
| 853 | RewritePatternSet &add(PDLPatternModule &&pattern) { |
| 854 | pdlPatterns.mergeIn(std::move(pattern)); |
| 855 | return *this; |
| 856 | } |
| 857 | |
| 858 | // Add a matchAndRewrite style pattern represented as a C function pointer. |
| 859 | template <typename OpType> |
| 860 | RewritePatternSet & |
| 861 | add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), |
| 862 | PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) { |
| 863 | struct FnPattern final : public OpRewritePattern<OpType> { |
| 864 | FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), |
| 865 | MLIRContext *context, PatternBenefit benefit, |
| 866 | ArrayRef<StringRef> generatedNames) |
| 867 | : OpRewritePattern<OpType>(context, benefit, generatedNames), |
| 868 | implFn(implFn) {} |
| 869 | |
| 870 | LogicalResult matchAndRewrite(OpType op, |
| 871 | PatternRewriter &rewriter) const override { |
| 872 | return implFn(op, rewriter); |
| 873 | } |
| 874 | |
| 875 | private: |
| 876 | LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); |
| 877 | }; |
| 878 | add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit, |
| 879 | generatedNames)); |
| 880 | return *this; |
| 881 | } |
| 882 | |
| 883 | //===--------------------------------------------------------------------===// |
| 884 | // Pattern Insertion |
| 885 | //===--------------------------------------------------------------------===// |
| 886 | |
| 887 | // TODO: These are soft deprecated in favor of the 'add' methods above. |
| 888 | |
| 889 | /// Add an instance of each of the pattern types 'Ts' to the pattern list with |
| 890 | /// the given arguments. Return a reference to `this` for chaining insertions. |
| 891 | /// Note: ConstructorArg is necessary here to separate the two variadic lists. |
| 892 | template <typename... Ts, typename ConstructorArg, |
| 893 | typename... ConstructorArgs, |
| 894 | typename = std::enable_if_t<sizeof...(Ts) != 0>> |
| 895 | RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) { |
| 896 | // The following expands a call to emplace_back for each of the pattern |
| 897 | // types 'Ts'. |
| 898 | (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...); |
| 899 | return *this; |
| 900 | } |
| 901 | |
| 902 | /// Add an instance of each of the pattern types 'Ts'. Return a reference to |
| 903 | /// `this` for chaining insertions. |
| 904 | template <typename... Ts> |
| 905 | RewritePatternSet &insert() { |
| 906 | (addImpl<Ts>(), ...); |
| 907 | return *this; |
| 908 | } |
| 909 | |
| 910 | /// Add the given native pattern to the pattern list. Return a reference to |
| 911 | /// `this` for chaining insertions. |
| 912 | RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) { |
| 913 | nativePatterns.emplace_back(args: std::move(pattern)); |
| 914 | return *this; |
| 915 | } |
| 916 | |
| 917 | /// Add the given PDL pattern to the pattern list. Return a reference to |
| 918 | /// `this` for chaining insertions. |
| 919 | RewritePatternSet &insert(PDLPatternModule &&pattern) { |
| 920 | pdlPatterns.mergeIn(std::move(pattern)); |
| 921 | return *this; |
| 922 | } |
| 923 | |
| 924 | // Add a matchAndRewrite style pattern represented as a C function pointer. |
| 925 | template <typename OpType> |
| 926 | RewritePatternSet & |
| 927 | insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) { |
| 928 | struct FnPattern final : public OpRewritePattern<OpType> { |
| 929 | FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), |
| 930 | MLIRContext *context) |
| 931 | : OpRewritePattern<OpType>(context), implFn(implFn) { |
| 932 | this->setDebugName(llvm::getTypeName<FnPattern>()); |
| 933 | } |
| 934 | |
| 935 | LogicalResult matchAndRewrite(OpType op, |
| 936 | PatternRewriter &rewriter) const override { |
| 937 | return implFn(op, rewriter); |
| 938 | } |
| 939 | |
| 940 | private: |
| 941 | LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); |
| 942 | }; |
| 943 | add(std::make_unique<FnPattern>(std::move(implFn), getContext())); |
| 944 | return *this; |
| 945 | } |
| 946 | |
| 947 | private: |
| 948 | /// Add an instance of the pattern type 'T'. Return a reference to `this` for |
| 949 | /// chaining insertions. |
| 950 | template <typename T, typename... Args> |
| 951 | std::enable_if_t<std::is_base_of<RewritePattern, T>::value> |
| 952 | addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { |
| 953 | std::unique_ptr<T> pattern = |
| 954 | RewritePattern::create<T>(std::forward<Args>(args)...); |
| 955 | pattern->addDebugLabels(debugLabels); |
| 956 | nativePatterns.emplace_back(std::move(pattern)); |
| 957 | } |
| 958 | |
| 959 | template <typename T, typename... Args> |
| 960 | std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value> |
| 961 | addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { |
| 962 | // TODO: Add the provided labels to the PDL pattern when PDL supports |
| 963 | // labels. |
| 964 | pdlPatterns.mergeIn(T(std::forward<Args>(args)...)); |
| 965 | } |
| 966 | |
| 967 | MLIRContext *const context; |
| 968 | NativePatternListT nativePatterns; |
| 969 | |
| 970 | // Patterns expressed with PDL. This will compile to a stub class when PDL is |
| 971 | // not enabled. |
| 972 | PDLPatternModule pdlPatterns; |
| 973 | }; |
| 974 | |
| 975 | } // namespace mlir |
| 976 | |
| 977 | #endif // MLIR_IR_PATTERNMATCH_H |
| 978 | |