1//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements mlir::applyPatternsAndFoldGreedily.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14
15#include "mlir/Config/mlir-config.h"
16#include "mlir/IR/Action.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Verifier.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Rewrite/PatternApplicator.h"
21#include "mlir/Transforms/FoldUtils.h"
22#include "mlir/Transforms/RegionUtils.h"
23#include "llvm/ADT/BitVector.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/ScopeExit.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/ScopedPrinter.h"
29#include "llvm/Support/raw_ostream.h"
30
31#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32#include <random>
33#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34
35using namespace mlir;
36
37#define DEBUG_TYPE "greedy-rewriter"
38
39namespace {
40
41//===----------------------------------------------------------------------===//
42// Debugging Infrastructure
43//===----------------------------------------------------------------------===//
44
45#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46/// A helper struct that performs various "expensive checks" to detect broken
47/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48/// broken if:
49/// * IR does not verify after pattern application / folding.
50/// * Pattern returns "failure" but the IR has changed.
51/// * Pattern returns "success" but the IR has not changed.
52///
53/// This struct stores finger prints of ops to determine whether the IR has
54/// changed or not.
55struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56 ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57 : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
58
59 /// Compute finger prints of the given op and its nested ops.
60 void computeFingerPrints(Operation *topLevel) {
61 this->topLevel = topLevel;
62 this->topLevelFingerPrint.emplace(topLevel);
63 topLevel->walk([&](Operation *op) {
64 fingerprints.try_emplace(op, op, /*includeNested=*/false);
65 });
66 }
67
68 /// Clear all finger prints.
69 void clear() {
70 topLevel = nullptr;
71 topLevelFingerPrint.reset();
72 fingerprints.clear();
73 }
74
75 void notifyRewriteSuccess() {
76 if (!topLevel)
77 return;
78
79 // Make sure that the IR still verifies.
80 if (failed(verify(topLevel)))
81 llvm::report_fatal_error("IR failed to verify after pattern application");
82
83 // Pattern application success => IR must have changed.
84 OperationFingerPrint afterFingerPrint(topLevel);
85 if (*topLevelFingerPrint == afterFingerPrint) {
86 // Note: Run "mlir-opt -debug" to see which pattern is broken.
87 llvm::report_fatal_error(
88 "pattern returned success but IR did not change");
89 }
90 for (const auto &it : fingerprints) {
91 // Skip top-level op, its finger print is never invalidated.
92 if (it.first == topLevel)
93 continue;
94 // Note: Finger print computation may crash when an op was erased
95 // without notifying the rewriter. (Run with ASAN to see where the op was
96 // erased; the op was probably erased directly, bypassing the rewriter
97 // API.) Finger print computation does may not crash if a new op was
98 // created at the same memory location. (But then the finger print should
99 // have changed.)
100 if (it.second !=
101 OperationFingerPrint(it.first, /*includeNested=*/false)) {
102 // Note: Run "mlir-opt -debug" to see which pattern is broken.
103 llvm::report_fatal_error("operation finger print changed");
104 }
105 }
106 }
107
108 void notifyRewriteFailure() {
109 if (!topLevel)
110 return;
111
112 // Pattern application failure => IR must not have changed.
113 OperationFingerPrint afterFingerPrint(topLevel);
114 if (*topLevelFingerPrint != afterFingerPrint) {
115 // Note: Run "mlir-opt -debug" to see which pattern is broken.
116 llvm::report_fatal_error("pattern returned failure but IR did change");
117 }
118 }
119
120 void notifyFoldingSuccess() {
121 if (!topLevel)
122 return;
123
124 // Make sure that the IR still verifies.
125 if (failed(verify(topLevel)))
126 llvm::report_fatal_error("IR failed to verify after folding");
127 }
128
129protected:
130 /// Invalidate the finger print of the given op, i.e., remove it from the map.
131 void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
132
133 void notifyBlockErased(Block *block) override {
134 RewriterBase::ForwardingListener::notifyBlockErased(block);
135
136 // The block structure (number of blocks, types of block arguments, etc.)
137 // is part of the fingerprint of the parent op.
138 // TODO: The parent op fingerprint should also be invalidated when modifying
139 // the block arguments of a block, but we do not have a
140 // `notifyBlockModified` callback yet.
141 invalidateFingerPrint(block->getParentOp());
142 }
143
144 void notifyOperationInserted(Operation *op,
145 OpBuilder::InsertPoint previous) override {
146 RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
147 invalidateFingerPrint(op->getParentOp());
148 }
149
150 void notifyOperationModified(Operation *op) override {
151 RewriterBase::ForwardingListener::notifyOperationModified(op);
152 invalidateFingerPrint(op);
153 }
154
155 void notifyOperationErased(Operation *op) override {
156 RewriterBase::ForwardingListener::notifyOperationErased(op);
157 op->walk([this](Operation *op) { invalidateFingerPrint(op); });
158 }
159
160 /// Operation finger prints to detect invalid pattern API usage. IR is checked
161 /// against these finger prints after pattern application to detect cases
162 /// where IR was modified directly, bypassing the rewriter API.
163 DenseMap<Operation *, OperationFingerPrint> fingerprints;
164
165 /// Top-level operation of the current greedy rewrite.
166 Operation *topLevel = nullptr;
167
168 /// Finger print of the top-level operation.
169 std::optional<OperationFingerPrint> topLevelFingerPrint;
170};
171#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172
173#ifndef NDEBUG
174static Operation *getDumpRootOp(Operation *op) {
175 // Dump the parent op so that materialized constants are visible. If the op
176 // is a top-level op, dump it directly.
177 if (Operation *parentOp = op->getParentOp())
178 return parentOp;
179 return op;
180}
181static void logSuccessfulFolding(Operation *op) {
182 llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183 op->dump();
184 llvm::dbgs() << "\n\n";
185}
186#endif // NDEBUG
187
188//===----------------------------------------------------------------------===//
189// Worklist
190//===----------------------------------------------------------------------===//
191
192/// A LIFO worklist of operations with efficient removal and set semantics.
193///
194/// This class maintains a vector of operations and a mapping of operations to
195/// positions in the vector, so that operations can be removed efficiently at
196/// random. When an operation is removed, it is replaced with nullptr. Such
197/// nullptr are skipped when pop'ing elements.
198class Worklist {
199public:
200 Worklist();
201
202 /// Clear the worklist.
203 void clear();
204
205 /// Return whether the worklist is empty.
206 bool empty() const;
207
208 /// Push an operation to the end of the worklist, unless the operation is
209 /// already on the worklist.
210 void push(Operation *op);
211
212 /// Pop the an operation from the end of the worklist. Only allowed on
213 /// non-empty worklists.
214 Operation *pop();
215
216 /// Remove an operation from the worklist.
217 void remove(Operation *op);
218
219 /// Reverse the worklist.
220 void reverse();
221
222protected:
223 /// The worklist of operations.
224 std::vector<Operation *> list;
225
226 /// A mapping of operations to positions in `list`.
227 DenseMap<Operation *, unsigned> map;
228};
229
230Worklist::Worklist() { list.reserve(n: 64); }
231
232void Worklist::clear() {
233 list.clear();
234 map.clear();
235}
236
237bool Worklist::empty() const {
238 // Skip all nullptr.
239 return !llvm::any_of(Range: list,
240 P: [](Operation *op) { return static_cast<bool>(op); });
241}
242
243void Worklist::push(Operation *op) {
244 assert(op && "cannot push nullptr to worklist");
245 // Check to see if the worklist already contains this op.
246 if (!map.insert(KV: {op, list.size()}).second)
247 return;
248 list.push_back(x: op);
249}
250
251Operation *Worklist::pop() {
252 assert(!empty() && "cannot pop from empty worklist");
253 // Skip and remove all trailing nullptr.
254 while (!list.back())
255 list.pop_back();
256 Operation *op = list.back();
257 list.pop_back();
258 map.erase(Val: op);
259 // Cleanup: Remove all trailing nullptr.
260 while (!list.empty() && !list.back())
261 list.pop_back();
262 return op;
263}
264
265void Worklist::remove(Operation *op) {
266 assert(op && "cannot remove nullptr from worklist");
267 auto it = map.find(Val: op);
268 if (it != map.end()) {
269 assert(list[it->second] == op && "malformed worklist data structure");
270 list[it->second] = nullptr;
271 map.erase(I: it);
272 }
273}
274
275void Worklist::reverse() {
276 std::reverse(first: list.begin(), last: list.end());
277 for (size_t i = 0, e = list.size(); i != e; ++i)
278 map[list[i]] = i;
279}
280
281#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282/// A worklist that pops elements at a random position. This worklist is for
283/// testing/debugging purposes only. It can be used to ensure that lowering
284/// pipelines work correctly regardless of the order in which ops are processed
285/// by the GreedyPatternRewriteDriver.
286class RandomizedWorklist : public Worklist {
287public:
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290 }
291
292 /// Pop a random non-empty op from the worklist.
293 Operation *pop() {
294 Operation *op = nullptr;
295 do {
296 assert(!list.empty() && "cannot pop from empty worklist");
297 int64_t pos = generator() % list.size();
298 op = list[pos];
299 list.erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
301 map[list[i]] = i;
302 map.erase(op);
303 } while (!op);
304 return op;
305 }
306
307private:
308 std::minstd_rand0 generator;
309};
310#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311
312//===----------------------------------------------------------------------===//
313// GreedyPatternRewriteDriver
314//===----------------------------------------------------------------------===//
315
316/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
317/// applies the locally optimal patterns.
318///
319/// This abstract class manages the worklist and contains helper methods for
320/// rewriting ops on the worklist. Derived classes specify how ops are added
321/// to the worklist in the beginning.
322class GreedyPatternRewriteDriver : public PatternRewriter,
323 public RewriterBase::Listener {
324protected:
325 explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
326 const FrozenRewritePatternSet &patterns,
327 const GreedyRewriteConfig &config);
328
329 /// Add the given operation to the worklist.
330 void addSingleOpToWorklist(Operation *op);
331
332 /// Add the given operation and its ancestors to the worklist.
333 void addToWorklist(Operation *op);
334
335 /// Notify the driver that the specified operation may have been modified
336 /// in-place. The operation is added to the worklist.
337 void notifyOperationModified(Operation *op) override;
338
339 /// Notify the driver that the specified operation was inserted. Update the
340 /// worklist as needed: The operation is enqueued depending on scope and
341 /// strict mode.
342 void notifyOperationInserted(Operation *op, InsertPoint previous) override;
343
344 /// Notify the driver that the specified operation was removed. Update the
345 /// worklist as needed: The operation and its children are removed from the
346 /// worklist.
347 void notifyOperationErased(Operation *op) override;
348
349 /// Notify the driver that the specified operation was replaced. Update the
350 /// worklist as needed: New users are added enqueued.
351 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
352
353 /// Process ops until the worklist is empty or `config.maxNumRewrites` is
354 /// reached. Return `true` if any IR was changed.
355 bool processWorklist();
356
357 /// The worklist for this transformation keeps track of the operations that
358 /// need to be (re)visited.
359#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
360 RandomizedWorklist worklist;
361#else
362 Worklist worklist;
363#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
364
365 /// Configuration information for how to simplify.
366 const GreedyRewriteConfig config;
367
368 /// The list of ops we are restricting our rewrites to. These include the
369 /// supplied set of ops as well as new ops created while rewriting those ops
370 /// depending on `strictMode`. This set is not maintained when
371 /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
372 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
373
374private:
375 /// Look over the provided operands for any defining operations that should
376 /// be re-added to the worklist. This function should be called when an
377 /// operation is modified or removed, as it may trigger further
378 /// simplifications.
379 void addOperandsToWorklist(Operation *op);
380
381 /// Notify the driver that the given block was inserted.
382 void notifyBlockInserted(Block *block, Region *previous,
383 Region::iterator previousIt) override;
384
385 /// Notify the driver that the given block is about to be removed.
386 void notifyBlockErased(Block *block) override;
387
388 /// For debugging only: Notify the driver of a pattern match failure.
389 void
390 notifyMatchFailure(Location loc,
391 function_ref<void(Diagnostic &)> reasonCallback) override;
392
393#ifndef NDEBUG
394 /// A logger used to emit information during the application process.
395 llvm::ScopedPrinter logger{llvm::dbgs()};
396#endif
397
398 /// The low-level pattern applicator.
399 PatternApplicator matcher;
400
401#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
402 ExpensiveChecks expensiveChecks;
403#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
404};
405} // namespace
406
407GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
408 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
409 const GreedyRewriteConfig &config)
410 : PatternRewriter(ctx), config(config), matcher(patterns)
411#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
412 // clang-format off
413 , expensiveChecks(
414 /*driver=*/this,
415 /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
416// clang-format on
417#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
418{
419 // Apply a simple cost model based solely on pattern benefit.
420 matcher.applyDefaultCostModel();
421
422 // Set up listener.
423#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
424 // Send IR notifications to the debug handler. This handler will then forward
425 // all notifications to this GreedyPatternRewriteDriver.
426 setListener(&expensiveChecks);
427#else
428 setListener(this);
429#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
430}
431
432bool GreedyPatternRewriteDriver::processWorklist() {
433#ifndef NDEBUG
434 const char *logLineComment =
435 "//===-------------------------------------------===//\n";
436
437 /// A utility function to log a process result for the given reason.
438 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
439 logger.unindent();
440 logger.startLine() << "} -> " << result;
441 if (!msg.isTriviallyEmpty())
442 logger.getOStream() << " : " << msg;
443 logger.getOStream() << "\n";
444 };
445 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
446 logResult(result, msg);
447 logger.startLine() << logLineComment;
448 };
449#endif
450
451 bool changed = false;
452 int64_t numRewrites = 0;
453 while (!worklist.empty() &&
454 (numRewrites < config.maxNumRewrites ||
455 config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
456 auto *op = worklist.pop();
457
458 LLVM_DEBUG({
459 logger.getOStream() << "\n";
460 logger.startLine() << logLineComment;
461 logger.startLine() << "Processing operation : '" << op->getName() << "'("
462 << op << ") {\n";
463 logger.indent();
464
465 // If the operation has no regions, just print it here.
466 if (op->getNumRegions() == 0) {
467 op->print(
468 logger.startLine(),
469 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
470 logger.getOStream() << "\n\n";
471 }
472 });
473
474 // If the operation is trivially dead - remove it.
475 if (isOpTriviallyDead(op)) {
476 eraseOp(op);
477 changed = true;
478
479 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
480 continue;
481 }
482
483 // Try to fold this op. Do not fold constant ops. That would lead to an
484 // infinite folding loop, as every constant op would be folded to an
485 // Attribute and then immediately be rematerialized as a constant op, which
486 // is then put on the worklist.
487 if (!op->hasTrait<OpTrait::ConstantLike>()) {
488 SmallVector<OpFoldResult> foldResults;
489 if (succeeded(result: op->fold(results&: foldResults))) {
490 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
491#ifndef NDEBUG
492 Operation *dumpRootOp = getDumpRootOp(op);
493#endif // NDEBUG
494 if (foldResults.empty()) {
495 // Op was modified in-place.
496 notifyOperationModified(op);
497 changed = true;
498 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
499#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
500 expensiveChecks.notifyFoldingSuccess();
501#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
502 continue;
503 }
504
505 // Op results can be replaced with `foldResults`.
506 assert(foldResults.size() == op->getNumResults() &&
507 "folder produced incorrect number of results");
508 OpBuilder::InsertionGuard g(*this);
509 setInsertionPoint(op);
510 SmallVector<Value> replacements;
511 bool materializationSucceeded = true;
512 for (auto [ofr, resultType] :
513 llvm::zip_equal(t&: foldResults, u: op->getResultTypes())) {
514 if (auto value = ofr.dyn_cast<Value>()) {
515 assert(value.getType() == resultType &&
516 "folder produced value of incorrect type");
517 replacements.push_back(Elt: value);
518 continue;
519 }
520 // Materialize Attributes as SSA values.
521 Operation *constOp = op->getDialect()->materializeConstant(
522 builder&: *this, value: ofr.get<Attribute>(), type: resultType, loc: op->getLoc());
523
524 if (!constOp) {
525 // If materialization fails, cleanup any operations generated for
526 // the previous results.
527 llvm::SmallDenseSet<Operation *> replacementOps;
528 for (Value replacement : replacements) {
529 assert(replacement.use_empty() &&
530 "folder reused existing op for one result but constant "
531 "materialization failed for another result");
532 replacementOps.insert(V: replacement.getDefiningOp());
533 }
534 for (Operation *op : replacementOps) {
535 eraseOp(op);
536 }
537
538 materializationSucceeded = false;
539 break;
540 }
541
542 assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
543 "materializeConstant produced op that is not a ConstantLike");
544 assert(constOp->getResultTypes()[0] == resultType &&
545 "materializeConstant produced incorrect result type");
546 replacements.push_back(Elt: constOp->getResult(idx: 0));
547 }
548
549 if (materializationSucceeded) {
550 replaceOp(op, newValues: replacements);
551 changed = true;
552 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
553#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
554 expensiveChecks.notifyFoldingSuccess();
555#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
556 continue;
557 }
558 }
559 }
560
561 // Try to match one of the patterns. The rewriter is automatically
562 // notified of any necessary changes, so there is nothing else to do
563 // here.
564 auto canApplyCallback = [&](const Pattern &pattern) {
565 LLVM_DEBUG({
566 logger.getOStream() << "\n";
567 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
568 << op->getName() << " -> (";
569 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
570 logger.getOStream() << ")' {\n";
571 logger.indent();
572 });
573 if (config.listener)
574 config.listener->notifyPatternBegin(pattern, op);
575 return true;
576 };
577 function_ref<bool(const Pattern &)> canApply = canApplyCallback;
578 auto onFailureCallback = [&](const Pattern &pattern) {
579 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
580 if (config.listener)
581 config.listener->notifyPatternEnd(pattern, status: failure());
582 };
583 function_ref<void(const Pattern &)> onFailure = onFailureCallback;
584 auto onSuccessCallback = [&](const Pattern &pattern) {
585 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
586 if (config.listener)
587 config.listener->notifyPatternEnd(pattern, status: success());
588 return success();
589 };
590 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
591
592#ifdef NDEBUG
593 // Optimization: PatternApplicator callbacks are not needed when running in
594 // optimized mode and without a listener.
595 if (!config.listener) {
596 canApply = nullptr;
597 onFailure = nullptr;
598 onSuccess = nullptr;
599 }
600#endif // NDEBUG
601
602#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
603 if (config.scope) {
604 expensiveChecks.computeFingerPrints(config.scope->getParentOp());
605 }
606 auto clearFingerprints =
607 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
608#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
609
610 LogicalResult matchResult =
611 matcher.matchAndRewrite(op, rewriter&: *this, canApply, onFailure, onSuccess);
612
613 if (succeeded(result: matchResult)) {
614 LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
615#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
616 expensiveChecks.notifyRewriteSuccess();
617#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
618 changed = true;
619 ++numRewrites;
620 } else {
621 LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
622#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
623 expensiveChecks.notifyRewriteFailure();
624#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
625 }
626 }
627
628 return changed;
629}
630
631void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
632 assert(op && "expected valid op");
633 // Gather potential ancestors while looking for a "scope" parent region.
634 SmallVector<Operation *, 8> ancestors;
635 Region *region = nullptr;
636 do {
637 ancestors.push_back(Elt: op);
638 region = op->getParentRegion();
639 if (config.scope == region) {
640 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
641 for (Operation *op : ancestors)
642 addSingleOpToWorklist(op);
643 return;
644 }
645 if (region == nullptr)
646 return;
647 } while ((op = region->getParentOp()));
648}
649
650void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
651 if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
652 strictModeFilteredOps.contains(V: op))
653 worklist.push(op);
654}
655
656void GreedyPatternRewriteDriver::notifyBlockInserted(
657 Block *block, Region *previous, Region::iterator previousIt) {
658 if (config.listener)
659 config.listener->notifyBlockInserted(block, previous, previousIt);
660}
661
662void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
663 if (config.listener)
664 config.listener->notifyBlockErased(block);
665}
666
667void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
668 InsertPoint previous) {
669 LLVM_DEBUG({
670 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
671 << ")\n";
672 });
673 if (config.listener)
674 config.listener->notifyOperationInserted(op, previous);
675 if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
676 strictModeFilteredOps.insert(V: op);
677 addToWorklist(op);
678}
679
680void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
681 LLVM_DEBUG({
682 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
683 << ")\n";
684 });
685 if (config.listener)
686 config.listener->notifyOperationModified(op);
687 addToWorklist(op);
688}
689
690void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
691 for (Value operand : op->getOperands()) {
692 // If this operand currently has at most 2 users, add its defining op to the
693 // worklist. Indeed, after the op is deleted, then the operand will have at
694 // most 1 user left. If it has 0 users left, it can be deleted too,
695 // and if it has 1 user left, there may be further canonicalization
696 // opportunities.
697 if (!operand)
698 continue;
699
700 auto *defOp = operand.getDefiningOp();
701 if (!defOp)
702 continue;
703
704 Operation *otherUser = nullptr;
705 bool hasMoreThanTwoUses = false;
706 for (auto user : operand.getUsers()) {
707 if (user == op || user == otherUser)
708 continue;
709 if (!otherUser) {
710 otherUser = user;
711 continue;
712 }
713 hasMoreThanTwoUses = true;
714 break;
715 }
716 if (hasMoreThanTwoUses)
717 continue;
718
719 addToWorklist(op: defOp);
720 }
721}
722
723void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
724 LLVM_DEBUG({
725 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
726 << ")\n";
727 });
728
729#ifndef NDEBUG
730 // Only ops that are within the configured scope are added to the worklist of
731 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
732 // the part of the IR that is taken into account for the "expensive checks".
733 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
734 // region, as that would break the worklist handling and the expensive checks.
735 if (config.scope && config.scope->getParentOp() == op)
736 llvm_unreachable(
737 "scope region must not be erased during greedy pattern rewrite");
738#endif // NDEBUG
739
740 if (config.listener)
741 config.listener->notifyOperationErased(op);
742
743 addOperandsToWorklist(op);
744 worklist.remove(op);
745
746 if (config.strictMode != GreedyRewriteStrictness::AnyOp)
747 strictModeFilteredOps.erase(V: op);
748}
749
750void GreedyPatternRewriteDriver::notifyOperationReplaced(
751 Operation *op, ValueRange replacement) {
752 LLVM_DEBUG({
753 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
754 << ")\n";
755 });
756 if (config.listener)
757 config.listener->notifyOperationReplaced(op, replacement);
758}
759
760void GreedyPatternRewriteDriver::notifyMatchFailure(
761 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
762 LLVM_DEBUG({
763 Diagnostic diag(loc, DiagnosticSeverity::Remark);
764 reasonCallback(diag);
765 logger.startLine() << "** Match Failure : " << diag.str() << "\n";
766 });
767 if (config.listener)
768 config.listener->notifyMatchFailure(loc, reasonCallback);
769}
770
771//===----------------------------------------------------------------------===//
772// RegionPatternRewriteDriver
773//===----------------------------------------------------------------------===//
774
775namespace {
776/// This driver simplfies all ops in a region.
777class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
778public:
779 explicit RegionPatternRewriteDriver(MLIRContext *ctx,
780 const FrozenRewritePatternSet &patterns,
781 const GreedyRewriteConfig &config,
782 Region &regions);
783
784 /// Simplify ops inside `region` and simplify the region itself. Return
785 /// success if the transformation converged.
786 LogicalResult simplify(bool *changed) &&;
787
788private:
789 /// The region that is simplified.
790 Region &region;
791};
792} // namespace
793
794RegionPatternRewriteDriver::RegionPatternRewriteDriver(
795 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
796 const GreedyRewriteConfig &config, Region &region)
797 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
798 // Populate strict mode ops.
799 if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
800 region.walk(callback: [&](Operation *op) { strictModeFilteredOps.insert(V: op); });
801 }
802}
803
804namespace {
805class GreedyPatternRewriteIteration
806 : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
807public:
808 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
809 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
810 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
811 iteration(iteration) {}
812 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
813 void print(raw_ostream &os) const override {
814 os << "GreedyPatternRewriteIteration(" << iteration << ")";
815 }
816
817private:
818 int64_t iteration = 0;
819};
820} // namespace
821
822LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
823 bool continueRewrites = false;
824 int64_t iteration = 0;
825 MLIRContext *ctx = getContext();
826 do {
827 // Check if the iteration limit was reached.
828 if (++iteration > config.maxIterations &&
829 config.maxIterations != GreedyRewriteConfig::kNoLimit)
830 break;
831
832 // New iteration: start with an empty worklist.
833 worklist.clear();
834
835 // `OperationFolder` CSE's constant ops (and may move them into parents
836 // regions to enable more aggressive CSE'ing).
837 OperationFolder folder(getContext(), this);
838 auto insertKnownConstant = [&](Operation *op) {
839 // Check for existing constants when populating the worklist. This avoids
840 // accidentally reversing the constant order during processing.
841 Attribute constValue;
842 if (matchPattern(op, pattern: m_Constant(bind_value: &constValue)))
843 if (!folder.insertKnownConstant(op, constValue))
844 return true;
845 return false;
846 };
847
848 if (!config.useTopDownTraversal) {
849 // Add operations to the worklist in postorder.
850 region.walk(callback: [&](Operation *op) {
851 if (!insertKnownConstant(op))
852 addToWorklist(op);
853 });
854 } else {
855 // Add all nested operations to the worklist in preorder.
856 region.walk<WalkOrder::PreOrder>(callback: [&](Operation *op) {
857 if (!insertKnownConstant(op)) {
858 addToWorklist(op);
859 return WalkResult::advance();
860 }
861 return WalkResult::skip();
862 });
863
864 // Reverse the list so our pop-back loop processes them in-order.
865 worklist.reverse();
866 }
867
868 ctx->executeAction<GreedyPatternRewriteIteration>(
869 actionFn: [&] {
870 continueRewrites = processWorklist();
871
872 // After applying patterns, make sure that the CFG of each of the
873 // regions is kept up to date.
874 if (config.enableRegionSimplification)
875 continueRewrites |= succeeded(result: simplifyRegions(rewriter&: *this, regions: region));
876 },
877 irUnits: {&region}, args&: iteration);
878 } while (continueRewrites);
879
880 if (changed)
881 *changed = iteration > 1;
882
883 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
884 return success(isSuccess: !continueRewrites);
885}
886
887LogicalResult
888mlir::applyPatternsAndFoldGreedily(Region &region,
889 const FrozenRewritePatternSet &patterns,
890 GreedyRewriteConfig config, bool *changed) {
891 // The top-level operation must be known to be isolated from above to
892 // prevent performing canonicalizations on operations defined at or above
893 // the region containing 'op'.
894 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
895 "patterns can only be applied to operations IsolatedFromAbove");
896
897 // Set scope if not specified.
898 if (!config.scope)
899 config.scope = &region;
900
901#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
902 if (failed(verify(config.scope->getParentOp())))
903 llvm::report_fatal_error(
904 "greedy pattern rewriter input IR failed to verify");
905#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
906
907 // Start the pattern driver.
908 RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
909 region);
910 LogicalResult converged = std::move(driver).simplify(changed);
911 LLVM_DEBUG(if (failed(converged)) {
912 llvm::dbgs() << "The pattern rewrite did not converge after scanning "
913 << config.maxIterations << " times\n";
914 });
915 return converged;
916}
917
918//===----------------------------------------------------------------------===//
919// MultiOpPatternRewriteDriver
920//===----------------------------------------------------------------------===//
921
922namespace {
923/// This driver simplfies a list of ops.
924class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
925public:
926 explicit MultiOpPatternRewriteDriver(
927 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
928 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
929 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
930
931 /// Simplify `ops`. Return `success` if the transformation converged.
932 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
933
934private:
935 void notifyOperationErased(Operation *op) override {
936 GreedyPatternRewriteDriver::notifyOperationErased(op);
937 if (survivingOps)
938 survivingOps->erase(V: op);
939 }
940
941 /// An optional set of ops that survived the rewrite. This set is populated
942 /// at the beginning of `simplifyLocally` with the inititally provided list
943 /// of ops.
944 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
945};
946} // namespace
947
948MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
949 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
950 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
951 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
952 : GreedyPatternRewriteDriver(ctx, patterns, config),
953 survivingOps(survivingOps) {
954 if (config.strictMode != GreedyRewriteStrictness::AnyOp)
955 strictModeFilteredOps.insert(I: ops.begin(), E: ops.end());
956
957 if (survivingOps) {
958 survivingOps->clear();
959 survivingOps->insert(I: ops.begin(), E: ops.end());
960 }
961}
962
963LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
964 bool *changed) && {
965 // Populate the initial worklist.
966 for (Operation *op : ops)
967 addSingleOpToWorklist(op);
968
969 // Process ops on the worklist.
970 bool result = processWorklist();
971 if (changed)
972 *changed = result;
973
974 return success(isSuccess: worklist.empty());
975}
976
977/// Find the region that is the closest common ancestor of all given ops.
978///
979/// Note: This function returns `nullptr` if there is a top-level op among the
980/// given list of ops.
981static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
982 assert(!ops.empty() && "expected at least one op");
983 // Fast path in case there is only one op.
984 if (ops.size() == 1)
985 return ops.front()->getParentRegion();
986
987 Region *region = ops.front()->getParentRegion();
988 ops = ops.drop_front();
989 int sz = ops.size();
990 llvm::BitVector remainingOps(sz, true);
991 while (region) {
992 int pos = -1;
993 // Iterate over all remaining ops.
994 while ((pos = remainingOps.find_first_in(Begin: pos + 1, End: sz)) != -1) {
995 // Is this op contained in `region`?
996 if (region->findAncestorOpInRegion(op&: *ops[pos]))
997 remainingOps.reset(Idx: pos);
998 }
999 if (remainingOps.none())
1000 break;
1001 region = region->getParentRegion();
1002 }
1003 return region;
1004}
1005
1006LogicalResult mlir::applyOpPatternsAndFold(
1007 ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1008 GreedyRewriteConfig config, bool *changed, bool *allErased) {
1009 if (ops.empty()) {
1010 if (changed)
1011 *changed = false;
1012 if (allErased)
1013 *allErased = true;
1014 return success();
1015 }
1016
1017 // Determine scope of rewrite.
1018 if (!config.scope) {
1019 // Compute scope if none was provided. The scope will remain `nullptr` if
1020 // there is a top-level op among `ops`.
1021 config.scope = findCommonAncestor(ops);
1022 } else {
1023 // If a scope was provided, make sure that all ops are in scope.
1024#ifndef NDEBUG
1025 bool allOpsInScope = llvm::all_of(Range&: ops, P: [&](Operation *op) {
1026 return static_cast<bool>(config.scope->findAncestorOpInRegion(op&: *op));
1027 });
1028 assert(allOpsInScope && "ops must be within the specified scope");
1029#endif // NDEBUG
1030 }
1031
1032#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1033 if (config.scope && failed(verify(config.scope->getParentOp())))
1034 llvm::report_fatal_error(
1035 "greedy pattern rewriter input IR failed to verify");
1036#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1037
1038 // Start the pattern driver.
1039 llvm::SmallDenseSet<Operation *, 4> surviving;
1040 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1041 config, ops,
1042 allErased ? &surviving : nullptr);
1043 LogicalResult converged = std::move(driver).simplify(ops, changed);
1044 if (allErased)
1045 *allErased = surviving.empty();
1046 LLVM_DEBUG(if (failed(converged)) {
1047 llvm::dbgs() << "The pattern rewrite did not converge after "
1048 << config.maxNumRewrites << " rewrites";
1049 });
1050 return converged;
1051}
1052

source code of mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp