1//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements mlir::applyPatternsGreedily.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14
15#include "mlir/Config/mlir-config.h"
16#include "mlir/IR/Action.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Verifier.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Rewrite/PatternApplicator.h"
21#include "mlir/Transforms/FoldUtils.h"
22#include "mlir/Transforms/RegionUtils.h"
23#include "llvm/ADT/BitVector.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/ScopeExit.h"
26#include "llvm/Support/Debug.h"
27#include "llvm/Support/ScopedPrinter.h"
28#include "llvm/Support/raw_ostream.h"
29
30#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
31#include <random>
32#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
33
34using namespace mlir;
35
36#define DEBUG_TYPE "greedy-rewriter"
37
38namespace {
39
40//===----------------------------------------------------------------------===//
41// Debugging Infrastructure
42//===----------------------------------------------------------------------===//
43
44#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
45/// A helper struct that performs various "expensive checks" to detect broken
46/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
47/// broken if:
48/// * IR does not verify after pattern application / folding.
49/// * Pattern returns "failure" but the IR has changed.
50/// * Pattern returns "success" but the IR has not changed.
51///
52/// This struct stores finger prints of ops to determine whether the IR has
53/// changed or not.
54struct ExpensiveChecks : public RewriterBase::ForwardingListener {
55 ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
56 : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
57
58 /// Compute finger prints of the given op and its nested ops.
59 void computeFingerPrints(Operation *topLevel) {
60 this->topLevel = topLevel;
61 this->topLevelFingerPrint.emplace(topLevel);
62 topLevel->walk([&](Operation *op) {
63 fingerprints.try_emplace(op, op, /*includeNested=*/false);
64 });
65 }
66
67 /// Clear all finger prints.
68 void clear() {
69 topLevel = nullptr;
70 topLevelFingerPrint.reset();
71 fingerprints.clear();
72 }
73
74 void notifyRewriteSuccess() {
75 if (!topLevel)
76 return;
77
78 // Make sure that the IR still verifies.
79 if (failed(verify(topLevel)))
80 llvm::report_fatal_error("IR failed to verify after pattern application");
81
82 // Pattern application success => IR must have changed.
83 OperationFingerPrint afterFingerPrint(topLevel);
84 if (*topLevelFingerPrint == afterFingerPrint) {
85 // Note: Run "mlir-opt -debug" to see which pattern is broken.
86 llvm::report_fatal_error(
87 "pattern returned success but IR did not change");
88 }
89 for (const auto &it : fingerprints) {
90 // Skip top-level op, its finger print is never invalidated.
91 if (it.first == topLevel)
92 continue;
93 // Note: Finger print computation may crash when an op was erased
94 // without notifying the rewriter. (Run with ASAN to see where the op was
95 // erased; the op was probably erased directly, bypassing the rewriter
96 // API.) Finger print computation does may not crash if a new op was
97 // created at the same memory location. (But then the finger print should
98 // have changed.)
99 if (it.second !=
100 OperationFingerPrint(it.first, /*includeNested=*/false)) {
101 // Note: Run "mlir-opt -debug" to see which pattern is broken.
102 llvm::report_fatal_error("operation finger print changed");
103 }
104 }
105 }
106
107 void notifyRewriteFailure() {
108 if (!topLevel)
109 return;
110
111 // Pattern application failure => IR must not have changed.
112 OperationFingerPrint afterFingerPrint(topLevel);
113 if (*topLevelFingerPrint != afterFingerPrint) {
114 // Note: Run "mlir-opt -debug" to see which pattern is broken.
115 llvm::report_fatal_error("pattern returned failure but IR did change");
116 }
117 }
118
119 void notifyFoldingSuccess() {
120 if (!topLevel)
121 return;
122
123 // Make sure that the IR still verifies.
124 if (failed(verify(topLevel)))
125 llvm::report_fatal_error("IR failed to verify after folding");
126 }
127
128protected:
129 /// Invalidate the finger print of the given op, i.e., remove it from the map.
130 void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
131
132 void notifyBlockErased(Block *block) override {
133 RewriterBase::ForwardingListener::notifyBlockErased(block);
134
135 // The block structure (number of blocks, types of block arguments, etc.)
136 // is part of the fingerprint of the parent op.
137 // TODO: The parent op fingerprint should also be invalidated when modifying
138 // the block arguments of a block, but we do not have a
139 // `notifyBlockModified` callback yet.
140 invalidateFingerPrint(block->getParentOp());
141 }
142
143 void notifyOperationInserted(Operation *op,
144 OpBuilder::InsertPoint previous) override {
145 RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
146 invalidateFingerPrint(op->getParentOp());
147 }
148
149 void notifyOperationModified(Operation *op) override {
150 RewriterBase::ForwardingListener::notifyOperationModified(op);
151 invalidateFingerPrint(op);
152 }
153
154 void notifyOperationErased(Operation *op) override {
155 RewriterBase::ForwardingListener::notifyOperationErased(op);
156 op->walk([this](Operation *op) { invalidateFingerPrint(op); });
157 }
158
159 /// Operation finger prints to detect invalid pattern API usage. IR is checked
160 /// against these finger prints after pattern application to detect cases
161 /// where IR was modified directly, bypassing the rewriter API.
162 DenseMap<Operation *, OperationFingerPrint> fingerprints;
163
164 /// Top-level operation of the current greedy rewrite.
165 Operation *topLevel = nullptr;
166
167 /// Finger print of the top-level operation.
168 std::optional<OperationFingerPrint> topLevelFingerPrint;
169};
170#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
171
172#ifndef NDEBUG
173static Operation *getDumpRootOp(Operation *op) {
174 // Dump the parent op so that materialized constants are visible. If the op
175 // is a top-level op, dump it directly.
176 if (Operation *parentOp = op->getParentOp())
177 return parentOp;
178 return op;
179}
180static void logSuccessfulFolding(Operation *op) {
181 llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
182 op->dump();
183 llvm::dbgs() << "\n\n";
184}
185#endif // NDEBUG
186
187//===----------------------------------------------------------------------===//
188// Worklist
189//===----------------------------------------------------------------------===//
190
191/// A LIFO worklist of operations with efficient removal and set semantics.
192///
193/// This class maintains a vector of operations and a mapping of operations to
194/// positions in the vector, so that operations can be removed efficiently at
195/// random. When an operation is removed, it is replaced with nullptr. Such
196/// nullptr are skipped when pop'ing elements.
197class Worklist {
198public:
199 Worklist();
200
201 /// Clear the worklist.
202 void clear();
203
204 /// Return whether the worklist is empty.
205 bool empty() const;
206
207 /// Push an operation to the end of the worklist, unless the operation is
208 /// already on the worklist.
209 void push(Operation *op);
210
211 /// Pop the an operation from the end of the worklist. Only allowed on
212 /// non-empty worklists.
213 Operation *pop();
214
215 /// Remove an operation from the worklist.
216 void remove(Operation *op);
217
218 /// Reverse the worklist.
219 void reverse();
220
221protected:
222 /// The worklist of operations.
223 std::vector<Operation *> list;
224
225 /// A mapping of operations to positions in `list`.
226 DenseMap<Operation *, unsigned> map;
227};
228
229Worklist::Worklist() { list.reserve(n: 64); }
230
231void Worklist::clear() {
232 list.clear();
233 map.clear();
234}
235
236bool Worklist::empty() const {
237 // Skip all nullptr.
238 return !llvm::any_of(Range: list,
239 P: [](Operation *op) { return static_cast<bool>(op); });
240}
241
242void Worklist::push(Operation *op) {
243 assert(op && "cannot push nullptr to worklist");
244 // Check to see if the worklist already contains this op.
245 if (!map.insert(KV: {op, list.size()}).second)
246 return;
247 list.push_back(x: op);
248}
249
250Operation *Worklist::pop() {
251 assert(!empty() && "cannot pop from empty worklist");
252 // Skip and remove all trailing nullptr.
253 while (!list.back())
254 list.pop_back();
255 Operation *op = list.back();
256 list.pop_back();
257 map.erase(Val: op);
258 // Cleanup: Remove all trailing nullptr.
259 while (!list.empty() && !list.back())
260 list.pop_back();
261 return op;
262}
263
264void Worklist::remove(Operation *op) {
265 assert(op && "cannot remove nullptr from worklist");
266 auto it = map.find(Val: op);
267 if (it != map.end()) {
268 assert(list[it->second] == op && "malformed worklist data structure");
269 list[it->second] = nullptr;
270 map.erase(I: it);
271 }
272}
273
274void Worklist::reverse() {
275 std::reverse(first: list.begin(), last: list.end());
276 for (size_t i = 0, e = list.size(); i != e; ++i)
277 map[list[i]] = i;
278}
279
280#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
281/// A worklist that pops elements at a random position. This worklist is for
282/// testing/debugging purposes only. It can be used to ensure that lowering
283/// pipelines work correctly regardless of the order in which ops are processed
284/// by the GreedyPatternRewriteDriver.
285class RandomizedWorklist : public Worklist {
286public:
287 RandomizedWorklist() : Worklist() {
288 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
289 }
290
291 /// Pop a random non-empty op from the worklist.
292 Operation *pop() {
293 Operation *op = nullptr;
294 do {
295 assert(!list.empty() && "cannot pop from empty worklist");
296 int64_t pos = generator() % list.size();
297 op = list[pos];
298 list.erase(list.begin() + pos);
299 for (int64_t i = pos, e = list.size(); i < e; ++i)
300 map[list[i]] = i;
301 map.erase(op);
302 } while (!op);
303 return op;
304 }
305
306private:
307 std::minstd_rand0 generator;
308};
309#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
310
311//===----------------------------------------------------------------------===//
312// GreedyPatternRewriteDriver
313//===----------------------------------------------------------------------===//
314
315/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
316/// applies the locally optimal patterns.
317///
318/// This abstract class manages the worklist and contains helper methods for
319/// rewriting ops on the worklist. Derived classes specify how ops are added
320/// to the worklist in the beginning.
321class GreedyPatternRewriteDriver : public RewriterBase::Listener {
322protected:
323 explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
324 const FrozenRewritePatternSet &patterns,
325 const GreedyRewriteConfig &config);
326
327 /// Add the given operation to the worklist.
328 void addSingleOpToWorklist(Operation *op);
329
330 /// Add the given operation and its ancestors to the worklist.
331 void addToWorklist(Operation *op);
332
333 /// Notify the driver that the specified operation may have been modified
334 /// in-place. The operation is added to the worklist.
335 void notifyOperationModified(Operation *op) override;
336
337 /// Notify the driver that the specified operation was inserted. Update the
338 /// worklist as needed: The operation is enqueued depending on scope and
339 /// strict mode.
340 void notifyOperationInserted(Operation *op,
341 OpBuilder::InsertPoint previous) override;
342
343 /// Notify the driver that the specified operation was removed. Update the
344 /// worklist as needed: The operation and its children are removed from the
345 /// worklist.
346 void notifyOperationErased(Operation *op) override;
347
348 /// Notify the driver that the specified operation was replaced. Update the
349 /// worklist as needed: New users are added enqueued.
350 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
351
352 /// Process ops until the worklist is empty or `config.maxNumRewrites` is
353 /// reached. Return `true` if any IR was changed.
354 bool processWorklist();
355
356 /// The pattern rewriter that is used for making IR modifications and is
357 /// passed to rewrite patterns.
358 PatternRewriter rewriter;
359
360 /// The worklist for this transformation keeps track of the operations that
361 /// need to be (re)visited.
362#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
363 RandomizedWorklist worklist;
364#else
365 Worklist worklist;
366#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
367
368 /// Configuration information for how to simplify.
369 const GreedyRewriteConfig config;
370
371 /// The list of ops we are restricting our rewrites to. These include the
372 /// supplied set of ops as well as new ops created while rewriting those ops
373 /// depending on `strictMode`. This set is not maintained when
374 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
375 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
376
377private:
378 /// Look over the provided operands for any defining operations that should
379 /// be re-added to the worklist. This function should be called when an
380 /// operation is modified or removed, as it may trigger further
381 /// simplifications.
382 void addOperandsToWorklist(Operation *op);
383
384 /// Notify the driver that the given block was inserted.
385 void notifyBlockInserted(Block *block, Region *previous,
386 Region::iterator previousIt) override;
387
388 /// Notify the driver that the given block is about to be removed.
389 void notifyBlockErased(Block *block) override;
390
391 /// For debugging only: Notify the driver of a pattern match failure.
392 void
393 notifyMatchFailure(Location loc,
394 function_ref<void(Diagnostic &)> reasonCallback) override;
395
396#ifndef NDEBUG
397 /// A logger used to emit information during the application process.
398 llvm::ScopedPrinter logger{llvm::dbgs()};
399#endif
400
401 /// The low-level pattern applicator.
402 PatternApplicator matcher;
403
404#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
405 ExpensiveChecks expensiveChecks;
406#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
407};
408} // namespace
409
410GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
411 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
412 const GreedyRewriteConfig &config)
413 : rewriter(ctx), config(config), matcher(patterns)
414#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
415 // clang-format off
416 , expensiveChecks(
417 /*driver=*/this,
418 /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
419 : nullptr)
420// clang-format on
421#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
422{
423 // Apply a simple cost model based solely on pattern benefit.
424 matcher.applyDefaultCostModel();
425
426 // Set up listener.
427#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
428 // Send IR notifications to the debug handler. This handler will then forward
429 // all notifications to this GreedyPatternRewriteDriver.
430 rewriter.setListener(&expensiveChecks);
431#else
432 rewriter.setListener(this);
433#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
434}
435
436bool GreedyPatternRewriteDriver::processWorklist() {
437#ifndef NDEBUG
438 const char *logLineComment =
439 "//===-------------------------------------------===//\n";
440
441 /// A utility function to log a process result for the given reason.
442 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
443 logger.unindent();
444 logger.startLine() << "} -> " << result;
445 if (!msg.isTriviallyEmpty())
446 logger.getOStream() << " : " << msg;
447 logger.getOStream() << "\n";
448 };
449 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
450 logResult(result, msg);
451 logger.startLine() << logLineComment;
452 };
453#endif
454
455 bool changed = false;
456 int64_t numRewrites = 0;
457 while (!worklist.empty() &&
458 (numRewrites < config.getMaxNumRewrites() ||
459 config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
460 auto *op = worklist.pop();
461
462 LLVM_DEBUG({
463 logger.getOStream() << "\n";
464 logger.startLine() << logLineComment;
465 logger.startLine() << "Processing operation : '" << op->getName() << "'("
466 << op << ") {\n";
467 logger.indent();
468
469 // If the operation has no regions, just print it here.
470 if (op->getNumRegions() == 0) {
471 op->print(
472 logger.startLine(),
473 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
474 logger.getOStream() << "\n\n";
475 }
476 });
477
478 // If the operation is trivially dead - remove it.
479 if (isOpTriviallyDead(op)) {
480 rewriter.eraseOp(op);
481 changed = true;
482
483 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
484 continue;
485 }
486
487 // Try to fold this op. Do not fold constant ops. That would lead to an
488 // infinite folding loop, as every constant op would be folded to an
489 // Attribute and then immediately be rematerialized as a constant op, which
490 // is then put on the worklist.
491 if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
492 SmallVector<OpFoldResult> foldResults;
493 if (succeeded(Result: op->fold(results&: foldResults))) {
494 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
495#ifndef NDEBUG
496 Operation *dumpRootOp = getDumpRootOp(op);
497#endif // NDEBUG
498 if (foldResults.empty()) {
499 // Op was modified in-place.
500 notifyOperationModified(op);
501 changed = true;
502 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
503#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
504 expensiveChecks.notifyFoldingSuccess();
505#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
506 continue;
507 }
508
509 // Op results can be replaced with `foldResults`.
510 assert(foldResults.size() == op->getNumResults() &&
511 "folder produced incorrect number of results");
512 OpBuilder::InsertionGuard g(rewriter);
513 rewriter.setInsertionPoint(op);
514 SmallVector<Value> replacements;
515 bool materializationSucceeded = true;
516 for (auto [ofr, resultType] :
517 llvm::zip_equal(t&: foldResults, u: op->getResultTypes())) {
518 if (auto value = dyn_cast<Value>(Val&: ofr)) {
519 assert(value.getType() == resultType &&
520 "folder produced value of incorrect type");
521 replacements.push_back(Elt: value);
522 continue;
523 }
524 // Materialize Attributes as SSA values.
525 Operation *constOp = op->getDialect()->materializeConstant(
526 builder&: rewriter, value: cast<Attribute>(Val&: ofr), type: resultType, loc: op->getLoc());
527
528 if (!constOp) {
529 // If materialization fails, cleanup any operations generated for
530 // the previous results.
531 llvm::SmallDenseSet<Operation *> replacementOps;
532 for (Value replacement : replacements) {
533 assert(replacement.use_empty() &&
534 "folder reused existing op for one result but constant "
535 "materialization failed for another result");
536 replacementOps.insert(V: replacement.getDefiningOp());
537 }
538 for (Operation *op : replacementOps) {
539 rewriter.eraseOp(op);
540 }
541
542 materializationSucceeded = false;
543 break;
544 }
545
546 assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
547 "materializeConstant produced op that is not a ConstantLike");
548 assert(constOp->getResultTypes()[0] == resultType &&
549 "materializeConstant produced incorrect result type");
550 replacements.push_back(Elt: constOp->getResult(idx: 0));
551 }
552
553 if (materializationSucceeded) {
554 rewriter.replaceOp(op, newValues: replacements);
555 changed = true;
556 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
557#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
558 expensiveChecks.notifyFoldingSuccess();
559#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
560 continue;
561 }
562 }
563 }
564
565 // Try to match one of the patterns. The rewriter is automatically
566 // notified of any necessary changes, so there is nothing else to do
567 // here.
568 auto canApplyCallback = [&](const Pattern &pattern) {
569 LLVM_DEBUG({
570 logger.getOStream() << "\n";
571 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
572 << op->getName() << " -> (";
573 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
574 logger.getOStream() << ")' {\n";
575 logger.indent();
576 });
577 if (RewriterBase::Listener *listener = config.getListener())
578 listener->notifyPatternBegin(pattern, op);
579 return true;
580 };
581 function_ref<bool(const Pattern &)> canApply = canApplyCallback;
582 auto onFailureCallback = [&](const Pattern &pattern) {
583 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
584 if (RewriterBase::Listener *listener = config.getListener())
585 listener->notifyPatternEnd(pattern, status: failure());
586 };
587 function_ref<void(const Pattern &)> onFailure = onFailureCallback;
588 auto onSuccessCallback = [&](const Pattern &pattern) {
589 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
590 if (RewriterBase::Listener *listener = config.getListener())
591 listener->notifyPatternEnd(pattern, status: success());
592 return success();
593 };
594 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
595
596#ifdef NDEBUG
597 // Optimization: PatternApplicator callbacks are not needed when running in
598 // optimized mode and without a listener.
599 if (!config.getListener()) {
600 canApply = nullptr;
601 onFailure = nullptr;
602 onSuccess = nullptr;
603 }
604#endif // NDEBUG
605
606#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
607 if (config.getScope()) {
608 expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
609 }
610 auto clearFingerprints =
611 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
612#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
613
614 LogicalResult matchResult =
615 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
616
617 if (succeeded(Result: matchResult)) {
618 LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
619#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
620 expensiveChecks.notifyRewriteSuccess();
621#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
622 changed = true;
623 ++numRewrites;
624 } else {
625 LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
626#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
627 expensiveChecks.notifyRewriteFailure();
628#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
629 }
630 }
631
632 return changed;
633}
634
635void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
636 assert(op && "expected valid op");
637 // Gather potential ancestors while looking for a "scope" parent region.
638 SmallVector<Operation *, 8> ancestors;
639 Region *region = nullptr;
640 do {
641 ancestors.push_back(Elt: op);
642 region = op->getParentRegion();
643 if (config.getScope() == region) {
644 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
645 for (Operation *op : ancestors)
646 addSingleOpToWorklist(op);
647 return;
648 }
649 if (region == nullptr)
650 return;
651 } while ((op = region->getParentOp()));
652}
653
654void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
655 if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
656 strictModeFilteredOps.contains(V: op))
657 worklist.push(op);
658}
659
660void GreedyPatternRewriteDriver::notifyBlockInserted(
661 Block *block, Region *previous, Region::iterator previousIt) {
662 if (RewriterBase::Listener *listener = config.getListener())
663 listener->notifyBlockInserted(block, previous, previousIt);
664}
665
666void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
667 if (RewriterBase::Listener *listener = config.getListener())
668 listener->notifyBlockErased(block);
669}
670
671void GreedyPatternRewriteDriver::notifyOperationInserted(
672 Operation *op, OpBuilder::InsertPoint previous) {
673 LLVM_DEBUG({
674 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
675 << ")\n";
676 });
677 if (RewriterBase::Listener *listener = config.getListener())
678 listener->notifyOperationInserted(op, previous);
679 if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
680 strictModeFilteredOps.insert(V: op);
681 addToWorklist(op);
682}
683
684void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
685 LLVM_DEBUG({
686 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
687 << ")\n";
688 });
689 if (RewriterBase::Listener *listener = config.getListener())
690 listener->notifyOperationModified(op);
691 addToWorklist(op);
692}
693
694void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
695 for (Value operand : op->getOperands()) {
696 // If this operand currently has at most 2 users, add its defining op to the
697 // worklist. Indeed, after the op is deleted, then the operand will have at
698 // most 1 user left. If it has 0 users left, it can be deleted too,
699 // and if it has 1 user left, there may be further canonicalization
700 // opportunities.
701 if (!operand)
702 continue;
703
704 auto *defOp = operand.getDefiningOp();
705 if (!defOp)
706 continue;
707
708 Operation *otherUser = nullptr;
709 bool hasMoreThanTwoUses = false;
710 for (auto user : operand.getUsers()) {
711 if (user == op || user == otherUser)
712 continue;
713 if (!otherUser) {
714 otherUser = user;
715 continue;
716 }
717 hasMoreThanTwoUses = true;
718 break;
719 }
720 if (hasMoreThanTwoUses)
721 continue;
722
723 addToWorklist(op: defOp);
724 }
725}
726
727void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
728 LLVM_DEBUG({
729 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
730 << ")\n";
731 });
732
733#ifndef NDEBUG
734 // Only ops that are within the configured scope are added to the worklist of
735 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
736 // the part of the IR that is taken into account for the "expensive checks".
737 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
738 // region, as that would break the worklist handling and the expensive checks.
739 if (Region *scope = config.getScope(); scope->getParentOp() == op)
740 llvm_unreachable(
741 "scope region must not be erased during greedy pattern rewrite");
742#endif // NDEBUG
743
744 if (RewriterBase::Listener *listener = config.getListener())
745 listener->notifyOperationErased(op);
746
747 addOperandsToWorklist(op);
748 worklist.remove(op);
749
750 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
751 strictModeFilteredOps.erase(V: op);
752}
753
754void GreedyPatternRewriteDriver::notifyOperationReplaced(
755 Operation *op, ValueRange replacement) {
756 LLVM_DEBUG({
757 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
758 << ")\n";
759 });
760 if (RewriterBase::Listener *listener = config.getListener())
761 listener->notifyOperationReplaced(op, replacement);
762}
763
764void GreedyPatternRewriteDriver::notifyMatchFailure(
765 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
766 LLVM_DEBUG({
767 Diagnostic diag(loc, DiagnosticSeverity::Remark);
768 reasonCallback(diag);
769 logger.startLine() << "** Match Failure : " << diag.str() << "\n";
770 });
771 if (RewriterBase::Listener *listener = config.getListener())
772 listener->notifyMatchFailure(loc, reasonCallback);
773}
774
775//===----------------------------------------------------------------------===//
776// RegionPatternRewriteDriver
777//===----------------------------------------------------------------------===//
778
779namespace {
780/// This driver simplfies all ops in a region.
781class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
782public:
783 explicit RegionPatternRewriteDriver(MLIRContext *ctx,
784 const FrozenRewritePatternSet &patterns,
785 const GreedyRewriteConfig &config,
786 Region &regions);
787
788 /// Simplify ops inside `region` and simplify the region itself. Return
789 /// success if the transformation converged.
790 LogicalResult simplify(bool *changed) &&;
791
792private:
793 /// The region that is simplified.
794 Region &region;
795};
796} // namespace
797
798RegionPatternRewriteDriver::RegionPatternRewriteDriver(
799 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
800 const GreedyRewriteConfig &config, Region &region)
801 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
802 // Populate strict mode ops.
803 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
804 region.walk(callback: [&](Operation *op) { strictModeFilteredOps.insert(V: op); });
805 }
806}
807
808namespace {
809class GreedyPatternRewriteIteration
810 : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
811public:
812 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
813 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
814 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
815 iteration(iteration) {}
816 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
817 void print(raw_ostream &os) const override {
818 os << "GreedyPatternRewriteIteration(" << iteration << ")";
819 }
820
821private:
822 int64_t iteration = 0;
823};
824} // namespace
825
826LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
827 bool continueRewrites = false;
828 int64_t iteration = 0;
829 MLIRContext *ctx = rewriter.getContext();
830 do {
831 // Check if the iteration limit was reached.
832 if (++iteration > config.getMaxIterations() &&
833 config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
834 break;
835
836 // New iteration: start with an empty worklist.
837 worklist.clear();
838
839 // `OperationFolder` CSE's constant ops (and may move them into parents
840 // regions to enable more aggressive CSE'ing).
841 OperationFolder folder(ctx, this);
842 auto insertKnownConstant = [&](Operation *op) {
843 // Check for existing constants when populating the worklist. This avoids
844 // accidentally reversing the constant order during processing.
845 Attribute constValue;
846 if (matchPattern(op, pattern: m_Constant(bind_value: &constValue)))
847 if (!folder.insertKnownConstant(op, constValue))
848 return true;
849 return false;
850 };
851
852 if (!config.getUseTopDownTraversal()) {
853 // Add operations to the worklist in postorder.
854 region.walk(callback: [&](Operation *op) {
855 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
856 addToWorklist(op);
857 });
858 } else {
859 // Add all nested operations to the worklist in preorder.
860 region.walk<WalkOrder::PreOrder>(callback: [&](Operation *op) {
861 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
862 addToWorklist(op);
863 return WalkResult::advance();
864 }
865 return WalkResult::skip();
866 });
867
868 // Reverse the list so our pop-back loop processes them in-order.
869 worklist.reverse();
870 }
871
872 ctx->executeAction<GreedyPatternRewriteIteration>(
873 actionFn: [&] {
874 continueRewrites = false;
875
876 // Erase unreachable blocks
877 // Operations like:
878 // %add = arith.addi %add, %add : i64
879 // are legal in unreachable code. Unfortunately many patterns would be
880 // unsafe to apply on such IR and can lead to crashes or infinite
881 // loops.
882 continueRewrites |=
883 succeeded(Result: eraseUnreachableBlocks(rewriter, regions: region));
884
885 continueRewrites |= processWorklist();
886
887 // After applying patterns, make sure that the CFG of each of the
888 // regions is kept up to date.
889 if (config.getRegionSimplificationLevel() !=
890 GreedySimplifyRegionLevel::Disabled) {
891 continueRewrites |= succeeded(Result: simplifyRegions(
892 rewriter, regions: region,
893 /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
894 GreedySimplifyRegionLevel::Aggressive));
895 }
896 },
897 irUnits: {&region}, args&: iteration);
898 } while (continueRewrites);
899
900 if (changed)
901 *changed = iteration > 1;
902
903 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
904 return success(IsSuccess: !continueRewrites);
905}
906
907LogicalResult
908mlir::applyPatternsGreedily(Region &region,
909 const FrozenRewritePatternSet &patterns,
910 GreedyRewriteConfig config, bool *changed) {
911 // The top-level operation must be known to be isolated from above to
912 // prevent performing canonicalizations on operations defined at or above
913 // the region containing 'op'.
914 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
915 "patterns can only be applied to operations IsolatedFromAbove");
916
917 // Set scope if not specified.
918 if (!config.getScope())
919 config.setScope(&region);
920
921#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
922 if (failed(verify(config.getScope()->getParentOp())))
923 llvm::report_fatal_error(
924 "greedy pattern rewriter input IR failed to verify");
925#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
926
927 // Start the pattern driver.
928 RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
929 region);
930 LogicalResult converged = std::move(driver).simplify(changed);
931 LLVM_DEBUG(if (failed(converged)) {
932 llvm::dbgs() << "The pattern rewrite did not converge after scanning "
933 << config.getMaxIterations() << " times\n";
934 });
935 return converged;
936}
937
938//===----------------------------------------------------------------------===//
939// MultiOpPatternRewriteDriver
940//===----------------------------------------------------------------------===//
941
942namespace {
943/// This driver simplfies a list of ops.
944class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
945public:
946 explicit MultiOpPatternRewriteDriver(
947 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
948 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
949 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
950
951 /// Simplify `ops`. Return `success` if the transformation converged.
952 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
953
954private:
955 void notifyOperationErased(Operation *op) override {
956 GreedyPatternRewriteDriver::notifyOperationErased(op);
957 if (survivingOps)
958 survivingOps->erase(V: op);
959 }
960
961 /// An optional set of ops that survived the rewrite. This set is populated
962 /// at the beginning of `simplifyLocally` with the inititally provided list
963 /// of ops.
964 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
965};
966} // namespace
967
968MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
969 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
970 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
971 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
972 : GreedyPatternRewriteDriver(ctx, patterns, config),
973 survivingOps(survivingOps) {
974 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
975 strictModeFilteredOps.insert_range(R&: ops);
976
977 if (survivingOps) {
978 survivingOps->clear();
979 survivingOps->insert_range(R&: ops);
980 }
981}
982
983LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
984 bool *changed) && {
985 // Populate the initial worklist.
986 for (Operation *op : ops)
987 addSingleOpToWorklist(op);
988
989 // Process ops on the worklist.
990 bool result = processWorklist();
991 if (changed)
992 *changed = result;
993
994 return success(IsSuccess: worklist.empty());
995}
996
997/// Find the region that is the closest common ancestor of all given ops.
998///
999/// Note: This function returns `nullptr` if there is a top-level op among the
1000/// given list of ops.
1001static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
1002 assert(!ops.empty() && "expected at least one op");
1003 // Fast path in case there is only one op.
1004 if (ops.size() == 1)
1005 return ops.front()->getParentRegion();
1006
1007 Region *region = ops.front()->getParentRegion();
1008 ops = ops.drop_front();
1009 int sz = ops.size();
1010 llvm::BitVector remainingOps(sz, true);
1011 while (region) {
1012 int pos = -1;
1013 // Iterate over all remaining ops.
1014 while ((pos = remainingOps.find_first_in(Begin: pos + 1, End: sz)) != -1) {
1015 // Is this op contained in `region`?
1016 if (region->findAncestorOpInRegion(op&: *ops[pos]))
1017 remainingOps.reset(Idx: pos);
1018 }
1019 if (remainingOps.none())
1020 break;
1021 region = region->getParentRegion();
1022 }
1023 return region;
1024}
1025
1026LogicalResult mlir::applyOpPatternsGreedily(
1027 ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1028 GreedyRewriteConfig config, bool *changed, bool *allErased) {
1029 if (ops.empty()) {
1030 if (changed)
1031 *changed = false;
1032 if (allErased)
1033 *allErased = true;
1034 return success();
1035 }
1036
1037 // Determine scope of rewrite.
1038 if (!config.getScope()) {
1039 // Compute scope if none was provided. The scope will remain `nullptr` if
1040 // there is a top-level op among `ops`.
1041 config.setScope(findCommonAncestor(ops));
1042 } else {
1043 // If a scope was provided, make sure that all ops are in scope.
1044#ifndef NDEBUG
1045 bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
1046 return static_cast<bool>(config.getScope()->findAncestorOpInRegion(*op));
1047 });
1048 assert(allOpsInScope && "ops must be within the specified scope");
1049#endif // NDEBUG
1050 }
1051
1052#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1053 if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
1054 llvm::report_fatal_error(
1055 "greedy pattern rewriter input IR failed to verify");
1056#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1057
1058 // Start the pattern driver.
1059 llvm::SmallDenseSet<Operation *, 4> surviving;
1060 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1061 config, ops,
1062 allErased ? &surviving : nullptr);
1063 LogicalResult converged = std::move(driver).simplify(ops, changed);
1064 if (allErased)
1065 *allErased = surviving.empty();
1066 LLVM_DEBUG(if (failed(converged)) {
1067 llvm::dbgs() << "The pattern rewrite did not converge after "
1068 << config.getMaxNumRewrites() << " rewrites";
1069 });
1070 return converged;
1071}
1072

// Source: mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp