1//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements mlir::applyPatternsGreedily.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14
15#include "mlir/Config/mlir-config.h"
16#include "mlir/IR/Action.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Verifier.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Rewrite/PatternApplicator.h"
21#include "mlir/Transforms/FoldUtils.h"
22#include "mlir/Transforms/RegionUtils.h"
23#include "llvm/ADT/BitVector.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/ADT/ScopeExit.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/ScopedPrinter.h"
29#include "llvm/Support/raw_ostream.h"
30
31#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
32#include <random>
33#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
34
35using namespace mlir;
36
37#define DEBUG_TYPE "greedy-rewriter"
38
39namespace {
40
41//===----------------------------------------------------------------------===//
42// Debugging Infrastructure
43//===----------------------------------------------------------------------===//
44
45#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that performs various "expensive checks" to detect broken
/// rewrite patterns that use the rewriter API incorrectly. A rewrite pattern is
/// broken if:
/// * IR does not verify after pattern application / folding.
/// * Pattern returns "failure" but the IR has changed.
/// * Pattern returns "success" but the IR has not changed.
///
/// This struct stores finger prints of ops to determine whether the IR has
/// changed or not. It sits between the rewriter and the driver: every IR
/// notification is forwarded to the wrapped listener (the driver) and is also
/// used to invalidate the finger prints of ops that were legitimately changed
/// through the rewriter API.
struct ExpensiveChecks : public RewriterBase::ForwardingListener {
  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
      : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}

  /// Compute finger prints of the given op and its nested ops. The top-level
  /// op gets a finger print that includes nested ops; every walked op
  /// (including the top-level op itself) also gets a non-nested finger print.
  void computeFingerPrints(Operation *topLevel) {
    this->topLevel = topLevel;
    this->topLevelFingerPrint.emplace(topLevel);
    topLevel->walk([&](Operation *op) {
      fingerprints.try_emplace(op, op, /*includeNested=*/false);
    });
  }

  /// Clear all finger prints and forget the top-level op, which disables the
  /// notify*() checks below until computeFingerPrints() is called again.
  void clear() {
    topLevel = nullptr;
    topLevelFingerPrint.reset();
    fingerprints.clear();
  }

  /// Check invariants after a pattern reported success: the IR verifies, the
  /// top-level finger print changed, and every op whose finger print was not
  /// invalidated through a rewriter notification is unchanged.
  void notifyRewriteSuccess() {
    if (!topLevel)
      return;

    // Make sure that the IR still verifies.
    if (failed(verify(topLevel)))
      llvm::report_fatal_error("IR failed to verify after pattern application");

    // Pattern application success => IR must have changed.
    OperationFingerPrint afterFingerPrint(topLevel);
    if (*topLevelFingerPrint == afterFingerPrint) {
      // Note: Run "mlir-opt -debug" to see which pattern is broken.
      llvm::report_fatal_error(
          "pattern returned success but IR did not change");
    }
    for (const auto &it : fingerprints) {
      // Skip top-level op, its finger print is never invalidated.
      if (it.first == topLevel)
        continue;
      // Note: Finger print computation may crash when an op was erased
      // without notifying the rewriter. (Run with ASAN to see where the op was
      // erased; the op was probably erased directly, bypassing the rewriter
      // API.) Finger print computation may not crash if a new op was
      // created at the same memory location. (But then the finger print should
      // have changed.)
      if (it.second !=
          OperationFingerPrint(it.first, /*includeNested=*/false)) {
        // Note: Run "mlir-opt -debug" to see which pattern is broken.
        llvm::report_fatal_error("operation finger print changed");
      }
    }
  }

  /// Check invariants after all patterns failed to match: the top-level finger
  /// print (which includes nested ops) must be unchanged.
  void notifyRewriteFailure() {
    if (!topLevel)
      return;

    // Pattern application failure => IR must not have changed.
    OperationFingerPrint afterFingerPrint(topLevel);
    if (*topLevelFingerPrint != afterFingerPrint) {
      // Note: Run "mlir-opt -debug" to see which pattern is broken.
      llvm::report_fatal_error("pattern returned failure but IR did change");
    }
  }

  /// Check that the IR still verifies after a successful fold.
  void notifyFoldingSuccess() {
    if (!topLevel)
      return;

    // Make sure that the IR still verifies.
    if (failed(verify(topLevel)))
      llvm::report_fatal_error("IR failed to verify after folding");
  }

protected:
  /// Invalidate the finger print of the given op, i.e., remove it from the map.
  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }

  void notifyBlockErased(Block *block) override {
    RewriterBase::ForwardingListener::notifyBlockErased(block);

    // The block structure (number of blocks, types of block arguments, etc.)
    // is part of the fingerprint of the parent op.
    // TODO: The parent op fingerprint should also be invalidated when modifying
    // the block arguments of a block, but we do not have a
    // `notifyBlockModified` callback yet.
    invalidateFingerPrint(block->getParentOp());
  }

  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
    // Inserting an op changes the contents of its parent op.
    invalidateFingerPrint(op->getParentOp());
  }

  void notifyOperationModified(Operation *op) override {
    RewriterBase::ForwardingListener::notifyOperationModified(op);
    invalidateFingerPrint(op);
  }

  void notifyOperationErased(Operation *op) override {
    RewriterBase::ForwardingListener::notifyOperationErased(op);
    // The erased op and everything nested in it must no longer be checked.
    op->walk([this](Operation *op) { invalidateFingerPrint(op); });
  }

  /// Operation finger prints to detect invalid pattern API usage. IR is checked
  /// against these finger prints after pattern application to detect cases
  /// where IR was modified directly, bypassing the rewriter API.
  DenseMap<Operation *, OperationFingerPrint> fingerprints;

  /// Top-level operation of the current greedy rewrite.
  Operation *topLevel = nullptr;

  /// Finger print of the top-level operation.
  std::optional<OperationFingerPrint> topLevelFingerPrint;
};
171#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
172
173#ifndef NDEBUG
174static Operation *getDumpRootOp(Operation *op) {
175 // Dump the parent op so that materialized constants are visible. If the op
176 // is a top-level op, dump it directly.
177 if (Operation *parentOp = op->getParentOp())
178 return parentOp;
179 return op;
180}
181static void logSuccessfulFolding(Operation *op) {
182 llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
183 op->dump();
184 llvm::dbgs() << "\n\n";
185}
186#endif // NDEBUG
187
188//===----------------------------------------------------------------------===//
189// Worklist
190//===----------------------------------------------------------------------===//
191
/// A LIFO worklist of operations with efficient removal and set semantics.
///
/// This class maintains a vector of operations and a mapping of operations to
/// positions in the vector, so that operations can be removed efficiently at
/// random. When an operation is removed, it is replaced with nullptr. Such
/// nullptr entries are skipped when popping elements.
class Worklist {
public:
  Worklist();

  /// Clear the worklist.
  void clear();

  /// Return whether the worklist is empty, i.e. holds no non-null entries.
  bool empty() const;

  /// Push an operation to the end of the worklist, unless the operation is
  /// already on the worklist.
  void push(Operation *op);

  /// Pop an operation from the end of the worklist. Only allowed on
  /// non-empty worklists. Never returns nullptr.
  Operation *pop();

  /// Remove an operation from the worklist (no-op if it is not on it).
  void remove(Operation *op);

  /// Reverse the worklist.
  void reverse();

protected:
  /// The worklist of operations. Removed entries are left as nullptr.
  std::vector<Operation *> list;

  /// A mapping of operations to positions in `list`.
  DenseMap<Operation *, unsigned> map;
};
229
230Worklist::Worklist() { list.reserve(n: 64); }
231
232void Worklist::clear() {
233 list.clear();
234 map.clear();
235}
236
237bool Worklist::empty() const {
238 // Skip all nullptr.
239 return !llvm::any_of(Range: list,
240 P: [](Operation *op) { return static_cast<bool>(op); });
241}
242
243void Worklist::push(Operation *op) {
244 assert(op && "cannot push nullptr to worklist");
245 // Check to see if the worklist already contains this op.
246 if (!map.insert(KV: {op, list.size()}).second)
247 return;
248 list.push_back(x: op);
249}
250
251Operation *Worklist::pop() {
252 assert(!empty() && "cannot pop from empty worklist");
253 // Skip and remove all trailing nullptr.
254 while (!list.back())
255 list.pop_back();
256 Operation *op = list.back();
257 list.pop_back();
258 map.erase(Val: op);
259 // Cleanup: Remove all trailing nullptr.
260 while (!list.empty() && !list.back())
261 list.pop_back();
262 return op;
263}
264
265void Worklist::remove(Operation *op) {
266 assert(op && "cannot remove nullptr from worklist");
267 auto it = map.find(Val: op);
268 if (it != map.end()) {
269 assert(list[it->second] == op && "malformed worklist data structure");
270 list[it->second] = nullptr;
271 map.erase(I: it);
272 }
273}
274
275void Worklist::reverse() {
276 std::reverse(first: list.begin(), last: list.end());
277 for (size_t i = 0, e = list.size(); i != e; ++i)
278 map[list[i]] = i;
279}
280
281#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
282/// A worklist that pops elements at a random position. This worklist is for
283/// testing/debugging purposes only. It can be used to ensure that lowering
284/// pipelines work correctly regardless of the order in which ops are processed
285/// by the GreedyPatternRewriteDriver.
286class RandomizedWorklist : public Worklist {
287public:
288 RandomizedWorklist() : Worklist() {
289 generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
290 }
291
292 /// Pop a random non-empty op from the worklist.
293 Operation *pop() {
294 Operation *op = nullptr;
295 do {
296 assert(!list.empty() && "cannot pop from empty worklist");
297 int64_t pos = generator() % list.size();
298 op = list[pos];
299 list.erase(list.begin() + pos);
300 for (int64_t i = pos, e = list.size(); i < e; ++i)
301 map[list[i]] = i;
302 map.erase(op);
303 } while (!op);
304 return op;
305 }
306
307private:
308 std::minstd_rand0 generator;
309};
310#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
311
312//===----------------------------------------------------------------------===//
313// GreedyPatternRewriteDriver
314//===----------------------------------------------------------------------===//
315
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns.
///
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                      const FrozenRewritePatternSet &patterns,
                                      const GreedyRewriteConfig &config);

  /// Add the given operation to the worklist, subject to the strict-mode
  /// filter.
  void addSingleOpToWorklist(Operation *op);

  /// Add the given operation and its ancestors to the worklist, provided the
  /// configured scope region is found among its ancestors.
  void addToWorklist(Operation *op);

  /// Notify the driver that the specified operation may have been modified
  /// in-place. The operation is added to the worklist.
  void notifyOperationModified(Operation *op) override;

  /// Notify the driver that the specified operation was inserted. Update the
  /// worklist as needed: The operation is enqueued depending on scope and
  /// strict mode.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Notify the driver that the specified operation was removed. Update the
  /// worklist as needed: The operation and its children are removed from the
  /// worklist.
  void notifyOperationErased(Operation *op) override;

  /// Notify the driver that the specified operation was replaced. The
  /// notification is forwarded to the configured listener, if any.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
  /// reached. Return `true` if any IR was changed.
  bool processWorklist();

  /// The pattern rewriter that is used for making IR modifications and is
  /// passed to rewrite patterns.
  PatternRewriter rewriter;

  /// The worklist for this transformation keeps track of the operations that
  /// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
  RandomizedWorklist worklist;
#else
  Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED

  /// Configuration information for how to simplify.
  const GreedyRewriteConfig config;

  /// The list of ops we are restricting our rewrites to. These include the
  /// supplied set of ops as well as new ops created while rewriting those ops
  /// depending on `strictMode`. This set is not maintained when
  /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
  llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;

private:
  /// Look over the provided operands for any defining operations that should
  /// be re-added to the worklist. This function should be called when an
  /// operation is modified or removed, as it may trigger further
  /// simplifications.
  void addOperandsToWorklist(Operation *op);

  /// Notify the driver that the given block was inserted.
  void notifyBlockInserted(Block *block, Region *previous,
                           Region::iterator previousIt) override;

  /// Notify the driver that the given block is about to be removed.
  void notifyBlockErased(Block *block) override;

  /// For debugging only: Notify the driver of a pattern match failure.
  void
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;

#ifndef NDEBUG
  /// A logger used to emit information during the application process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif

  /// The low-level pattern applicator.
  PatternApplicator matcher;

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  /// Listener that checks pattern API usage and forwards all IR notifications
  /// to this driver.
  ExpensiveChecks expensiveChecks;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
409} // namespace
410
411GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
412 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
413 const GreedyRewriteConfig &config)
414 : rewriter(ctx), config(config), matcher(patterns)
415#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
416 // clang-format off
417 , expensiveChecks(
418 /*driver=*/this,
419 /*topLevel=*/config.getScope() ? config.getScope()->getParentOp()
420 : nullptr)
421// clang-format on
422#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
423{
424 // Apply a simple cost model based solely on pattern benefit.
425 matcher.applyDefaultCostModel();
426
427 // Set up listener.
428#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
429 // Send IR notifications to the debug handler. This handler will then forward
430 // all notifications to this GreedyPatternRewriteDriver.
431 rewriter.setListener(&expensiveChecks);
432#else
433 rewriter.setListener(this);
434#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
435}
436
437bool GreedyPatternRewriteDriver::processWorklist() {
438#ifndef NDEBUG
439 const char *logLineComment =
440 "//===-------------------------------------------===//\n";
441
442 /// A utility function to log a process result for the given reason.
443 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
444 logger.unindent();
445 logger.startLine() << "} -> " << result;
446 if (!msg.isTriviallyEmpty())
447 logger.getOStream() << " : " << msg;
448 logger.getOStream() << "\n";
449 };
450 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
451 logResult(result, msg);
452 logger.startLine() << logLineComment;
453 };
454#endif
455
456 bool changed = false;
457 int64_t numRewrites = 0;
458 while (!worklist.empty() &&
459 (numRewrites < config.getMaxNumRewrites() ||
460 config.getMaxNumRewrites() == GreedyRewriteConfig::kNoLimit)) {
461 auto *op = worklist.pop();
462
463 LLVM_DEBUG({
464 logger.getOStream() << "\n";
465 logger.startLine() << logLineComment;
466 logger.startLine() << "Processing operation : '" << op->getName() << "'("
467 << op << ") {\n";
468 logger.indent();
469
470 // If the operation has no regions, just print it here.
471 if (op->getNumRegions() == 0) {
472 op->print(
473 logger.startLine(),
474 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
475 logger.getOStream() << "\n\n";
476 }
477 });
478
479 // If the operation is trivially dead - remove it.
480 if (isOpTriviallyDead(op)) {
481 rewriter.eraseOp(op);
482 changed = true;
483
484 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
485 continue;
486 }
487
488 // Try to fold this op. Do not fold constant ops. That would lead to an
489 // infinite folding loop, as every constant op would be folded to an
490 // Attribute and then immediately be rematerialized as a constant op, which
491 // is then put on the worklist.
492 if (config.isFoldingEnabled() && !op->hasTrait<OpTrait::ConstantLike>()) {
493 SmallVector<OpFoldResult> foldResults;
494 if (succeeded(Result: op->fold(results&: foldResults))) {
495 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
496#ifndef NDEBUG
497 Operation *dumpRootOp = getDumpRootOp(op);
498#endif // NDEBUG
499 if (foldResults.empty()) {
500 // Op was modified in-place.
501 notifyOperationModified(op);
502 changed = true;
503 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
504#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
505 expensiveChecks.notifyFoldingSuccess();
506#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
507 continue;
508 }
509
510 // Op results can be replaced with `foldResults`.
511 assert(foldResults.size() == op->getNumResults() &&
512 "folder produced incorrect number of results");
513 OpBuilder::InsertionGuard g(rewriter);
514 rewriter.setInsertionPoint(op);
515 SmallVector<Value> replacements;
516 bool materializationSucceeded = true;
517 for (auto [ofr, resultType] :
518 llvm::zip_equal(t&: foldResults, u: op->getResultTypes())) {
519 if (auto value = dyn_cast<Value>(Val&: ofr)) {
520 assert(value.getType() == resultType &&
521 "folder produced value of incorrect type");
522 replacements.push_back(Elt: value);
523 continue;
524 }
525 // Materialize Attributes as SSA values.
526 Operation *constOp = op->getDialect()->materializeConstant(
527 builder&: rewriter, value: cast<Attribute>(Val&: ofr), type: resultType, loc: op->getLoc());
528
529 if (!constOp) {
530 // If materialization fails, cleanup any operations generated for
531 // the previous results.
532 llvm::SmallDenseSet<Operation *> replacementOps;
533 for (Value replacement : replacements) {
534 assert(replacement.use_empty() &&
535 "folder reused existing op for one result but constant "
536 "materialization failed for another result");
537 replacementOps.insert(V: replacement.getDefiningOp());
538 }
539 for (Operation *op : replacementOps) {
540 rewriter.eraseOp(op);
541 }
542
543 materializationSucceeded = false;
544 break;
545 }
546
547 assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
548 "materializeConstant produced op that is not a ConstantLike");
549 assert(constOp->getResultTypes()[0] == resultType &&
550 "materializeConstant produced incorrect result type");
551 replacements.push_back(Elt: constOp->getResult(idx: 0));
552 }
553
554 if (materializationSucceeded) {
555 rewriter.replaceOp(op, newValues: replacements);
556 changed = true;
557 LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
558#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
559 expensiveChecks.notifyFoldingSuccess();
560#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
561 continue;
562 }
563 }
564 }
565
566 // Try to match one of the patterns. The rewriter is automatically
567 // notified of any necessary changes, so there is nothing else to do
568 // here.
569 auto canApplyCallback = [&](const Pattern &pattern) {
570 LLVM_DEBUG({
571 logger.getOStream() << "\n";
572 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
573 << op->getName() << " -> (";
574 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
575 logger.getOStream() << ")' {\n";
576 logger.indent();
577 });
578 if (RewriterBase::Listener *listener = config.getListener())
579 listener->notifyPatternBegin(pattern, op);
580 return true;
581 };
582 function_ref<bool(const Pattern &)> canApply = canApplyCallback;
583 auto onFailureCallback = [&](const Pattern &pattern) {
584 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
585 if (RewriterBase::Listener *listener = config.getListener())
586 listener->notifyPatternEnd(pattern, status: failure());
587 };
588 function_ref<void(const Pattern &)> onFailure = onFailureCallback;
589 auto onSuccessCallback = [&](const Pattern &pattern) {
590 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
591 if (RewriterBase::Listener *listener = config.getListener())
592 listener->notifyPatternEnd(pattern, status: success());
593 return success();
594 };
595 function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
596
597#ifdef NDEBUG
598 // Optimization: PatternApplicator callbacks are not needed when running in
599 // optimized mode and without a listener.
600 if (!config.getListener()) {
601 canApply = nullptr;
602 onFailure = nullptr;
603 onSuccess = nullptr;
604 }
605#endif // NDEBUG
606
607#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
608 if (config.getScope()) {
609 expensiveChecks.computeFingerPrints(config.getScope()->getParentOp());
610 }
611 auto clearFingerprints =
612 llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
613#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
614
615 LogicalResult matchResult =
616 matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
617
618 if (succeeded(Result: matchResult)) {
619 LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
620#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
621 expensiveChecks.notifyRewriteSuccess();
622#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
623 changed = true;
624 ++numRewrites;
625 } else {
626 LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
627#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
628 expensiveChecks.notifyRewriteFailure();
629#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
630 }
631 }
632
633 return changed;
634}
635
636void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
637 assert(op && "expected valid op");
638 // Gather potential ancestors while looking for a "scope" parent region.
639 SmallVector<Operation *, 8> ancestors;
640 Region *region = nullptr;
641 do {
642 ancestors.push_back(Elt: op);
643 region = op->getParentRegion();
644 if (config.getScope() == region) {
645 // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops.
646 for (Operation *op : ancestors)
647 addSingleOpToWorklist(op);
648 return;
649 }
650 if (region == nullptr)
651 return;
652 } while ((op = region->getParentOp()));
653}
654
655void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
656 if (config.getStrictness() == GreedyRewriteStrictness::AnyOp ||
657 strictModeFilteredOps.contains(V: op))
658 worklist.push(op);
659}
660
661void GreedyPatternRewriteDriver::notifyBlockInserted(
662 Block *block, Region *previous, Region::iterator previousIt) {
663 if (RewriterBase::Listener *listener = config.getListener())
664 listener->notifyBlockInserted(block, previous, previousIt);
665}
666
667void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
668 if (RewriterBase::Listener *listener = config.getListener())
669 listener->notifyBlockErased(block);
670}
671
672void GreedyPatternRewriteDriver::notifyOperationInserted(
673 Operation *op, OpBuilder::InsertPoint previous) {
674 LLVM_DEBUG({
675 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
676 << ")\n";
677 });
678 if (RewriterBase::Listener *listener = config.getListener())
679 listener->notifyOperationInserted(op, previous);
680 if (config.getStrictness() == GreedyRewriteStrictness::ExistingAndNewOps)
681 strictModeFilteredOps.insert(V: op);
682 addToWorklist(op);
683}
684
685void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
686 LLVM_DEBUG({
687 logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
688 << ")\n";
689 });
690 if (RewriterBase::Listener *listener = config.getListener())
691 listener->notifyOperationModified(op);
692 addToWorklist(op);
693}
694
695void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
696 for (Value operand : op->getOperands()) {
697 // If this operand currently has at most 2 users, add its defining op to the
698 // worklist. Indeed, after the op is deleted, then the operand will have at
699 // most 1 user left. If it has 0 users left, it can be deleted too,
700 // and if it has 1 user left, there may be further canonicalization
701 // opportunities.
702 if (!operand)
703 continue;
704
705 auto *defOp = operand.getDefiningOp();
706 if (!defOp)
707 continue;
708
709 Operation *otherUser = nullptr;
710 bool hasMoreThanTwoUses = false;
711 for (auto user : operand.getUsers()) {
712 if (user == op || user == otherUser)
713 continue;
714 if (!otherUser) {
715 otherUser = user;
716 continue;
717 }
718 hasMoreThanTwoUses = true;
719 break;
720 }
721 if (hasMoreThanTwoUses)
722 continue;
723
724 addToWorklist(op: defOp);
725 }
726}
727
728void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
729 LLVM_DEBUG({
730 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
731 << ")\n";
732 });
733
734#ifndef NDEBUG
735 // Only ops that are within the configured scope are added to the worklist of
736 // the greedy pattern rewriter. Moreover, the parent op of the scope region is
737 // the part of the IR that is taken into account for the "expensive checks".
738 // A greedy pattern rewrite is not allowed to erase the parent op of the scope
739 // region, as that would break the worklist handling and the expensive checks.
740 if (Region *scope = config.getScope(); scope->getParentOp() == op)
741 llvm_unreachable(
742 "scope region must not be erased during greedy pattern rewrite");
743#endif // NDEBUG
744
745 if (RewriterBase::Listener *listener = config.getListener())
746 listener->notifyOperationErased(op);
747
748 addOperandsToWorklist(op);
749 worklist.remove(op);
750
751 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
752 strictModeFilteredOps.erase(V: op);
753}
754
755void GreedyPatternRewriteDriver::notifyOperationReplaced(
756 Operation *op, ValueRange replacement) {
757 LLVM_DEBUG({
758 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
759 << ")\n";
760 });
761 if (RewriterBase::Listener *listener = config.getListener())
762 listener->notifyOperationReplaced(op, replacement);
763}
764
765void GreedyPatternRewriteDriver::notifyMatchFailure(
766 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
767 LLVM_DEBUG({
768 Diagnostic diag(loc, DiagnosticSeverity::Remark);
769 reasonCallback(diag);
770 logger.startLine() << "** Match Failure : " << diag.str() << "\n";
771 });
772 if (RewriterBase::Listener *listener = config.getListener())
773 listener->notifyMatchFailure(loc, reasonCallback);
774}
775
776//===----------------------------------------------------------------------===//
777// RegionPatternRewriteDriver
778//===----------------------------------------------------------------------===//
779
780namespace {
781/// This driver simplfies all ops in a region.
782class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
783public:
784 explicit RegionPatternRewriteDriver(MLIRContext *ctx,
785 const FrozenRewritePatternSet &patterns,
786 const GreedyRewriteConfig &config,
787 Region &regions);
788
789 /// Simplify ops inside `region` and simplify the region itself. Return
790 /// success if the transformation converged.
791 LogicalResult simplify(bool *changed) &&;
792
793private:
794 /// The region that is simplified.
795 Region &region;
796};
797} // namespace
798
799RegionPatternRewriteDriver::RegionPatternRewriteDriver(
800 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
801 const GreedyRewriteConfig &config, Region &region)
802 : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
803 // Populate strict mode ops.
804 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp) {
805 region.walk(callback: [&](Operation *op) { strictModeFilteredOps.insert(V: op); });
806 }
807}
808
809namespace {
810class GreedyPatternRewriteIteration
811 : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
812public:
813 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
814 GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
815 : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
816 iteration(iteration) {}
817 static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
818 void print(raw_ostream &os) const override {
819 os << "GreedyPatternRewriteIteration(" << iteration << ")";
820 }
821
822private:
823 int64_t iteration = 0;
824};
825} // namespace
826
827LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
828 bool continueRewrites = false;
829 int64_t iteration = 0;
830 MLIRContext *ctx = rewriter.getContext();
831 do {
832 // Check if the iteration limit was reached.
833 if (++iteration > config.getMaxIterations() &&
834 config.getMaxIterations() != GreedyRewriteConfig::kNoLimit)
835 break;
836
837 // New iteration: start with an empty worklist.
838 worklist.clear();
839
840 // `OperationFolder` CSE's constant ops (and may move them into parents
841 // regions to enable more aggressive CSE'ing).
842 OperationFolder folder(ctx, this);
843 auto insertKnownConstant = [&](Operation *op) {
844 // Check for existing constants when populating the worklist. This avoids
845 // accidentally reversing the constant order during processing.
846 Attribute constValue;
847 if (matchPattern(op, pattern: m_Constant(bind_value: &constValue)))
848 if (!folder.insertKnownConstant(op, constValue))
849 return true;
850 return false;
851 };
852
853 if (!config.getUseTopDownTraversal()) {
854 // Add operations to the worklist in postorder.
855 region.walk(callback: [&](Operation *op) {
856 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op))
857 addToWorklist(op);
858 });
859 } else {
860 // Add all nested operations to the worklist in preorder.
861 region.walk<WalkOrder::PreOrder>(callback: [&](Operation *op) {
862 if (!config.isConstantCSEEnabled() || !insertKnownConstant(op)) {
863 addToWorklist(op);
864 return WalkResult::advance();
865 }
866 return WalkResult::skip();
867 });
868
869 // Reverse the list so our pop-back loop processes them in-order.
870 worklist.reverse();
871 }
872
873 ctx->executeAction<GreedyPatternRewriteIteration>(
874 actionFn: [&] {
875 continueRewrites = processWorklist();
876
877 // After applying patterns, make sure that the CFG of each of the
878 // regions is kept up to date.
879 if (config.getRegionSimplificationLevel() !=
880 GreedySimplifyRegionLevel::Disabled) {
881 continueRewrites |= succeeded(Result: simplifyRegions(
882 rewriter, regions: region,
883 /*mergeBlocks=*/config.getRegionSimplificationLevel() ==
884 GreedySimplifyRegionLevel::Aggressive));
885 }
886 },
887 irUnits: {&region}, args&: iteration);
888 } while (continueRewrites);
889
890 if (changed)
891 *changed = iteration > 1;
892
893 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
894 return success(IsSuccess: !continueRewrites);
895}
896
897LogicalResult
898mlir::applyPatternsGreedily(Region &region,
899 const FrozenRewritePatternSet &patterns,
900 GreedyRewriteConfig config, bool *changed) {
901 // The top-level operation must be known to be isolated from above to
902 // prevent performing canonicalizations on operations defined at or above
903 // the region containing 'op'.
904 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
905 "patterns can only be applied to operations IsolatedFromAbove");
906
907 // Set scope if not specified.
908 if (!config.getScope())
909 config.setScope(&region);
910
911#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
912 if (failed(verify(config.getScope()->getParentOp())))
913 llvm::report_fatal_error(
914 "greedy pattern rewriter input IR failed to verify");
915#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
916
917 // Start the pattern driver.
918 RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
919 region);
920 LogicalResult converged = std::move(driver).simplify(changed);
921 LLVM_DEBUG(if (failed(converged)) {
922 llvm::dbgs() << "The pattern rewrite did not converge after scanning "
923 << config.getMaxIterations() << " times\n";
924 });
925 return converged;
926}
927
928//===----------------------------------------------------------------------===//
929// MultiOpPatternRewriteDriver
930//===----------------------------------------------------------------------===//
931
932namespace {
933/// This driver simplfies a list of ops.
934class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
935public:
936 explicit MultiOpPatternRewriteDriver(
937 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
938 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
939 llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
940
941 /// Simplify `ops`. Return `success` if the transformation converged.
942 LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
943
944private:
945 void notifyOperationErased(Operation *op) override {
946 GreedyPatternRewriteDriver::notifyOperationErased(op);
947 if (survivingOps)
948 survivingOps->erase(V: op);
949 }
950
951 /// An optional set of ops that survived the rewrite. This set is populated
952 /// at the beginning of `simplifyLocally` with the inititally provided list
953 /// of ops.
954 llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
955};
956} // namespace
957
958MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
959 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
960 const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
961 llvm::SmallDenseSet<Operation *, 4> *survivingOps)
962 : GreedyPatternRewriteDriver(ctx, patterns, config),
963 survivingOps(survivingOps) {
964 if (config.getStrictness() != GreedyRewriteStrictness::AnyOp)
965 strictModeFilteredOps.insert_range(R&: ops);
966
967 if (survivingOps) {
968 survivingOps->clear();
969 survivingOps->insert_range(R&: ops);
970 }
971}
972
973LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
974 bool *changed) && {
975 // Populate the initial worklist.
976 for (Operation *op : ops)
977 addSingleOpToWorklist(op);
978
979 // Process ops on the worklist.
980 bool result = processWorklist();
981 if (changed)
982 *changed = result;
983
984 return success(IsSuccess: worklist.empty());
985}
986
987/// Find the region that is the closest common ancestor of all given ops.
988///
989/// Note: This function returns `nullptr` if there is a top-level op among the
990/// given list of ops.
991static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
992 assert(!ops.empty() && "expected at least one op");
993 // Fast path in case there is only one op.
994 if (ops.size() == 1)
995 return ops.front()->getParentRegion();
996
997 Region *region = ops.front()->getParentRegion();
998 ops = ops.drop_front();
999 int sz = ops.size();
1000 llvm::BitVector remainingOps(sz, true);
1001 while (region) {
1002 int pos = -1;
1003 // Iterate over all remaining ops.
1004 while ((pos = remainingOps.find_first_in(Begin: pos + 1, End: sz)) != -1) {
1005 // Is this op contained in `region`?
1006 if (region->findAncestorOpInRegion(op&: *ops[pos]))
1007 remainingOps.reset(Idx: pos);
1008 }
1009 if (remainingOps.none())
1010 break;
1011 region = region->getParentRegion();
1012 }
1013 return region;
1014}
1015
1016LogicalResult mlir::applyOpPatternsGreedily(
1017 ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
1018 GreedyRewriteConfig config, bool *changed, bool *allErased) {
1019 if (ops.empty()) {
1020 if (changed)
1021 *changed = false;
1022 if (allErased)
1023 *allErased = true;
1024 return success();
1025 }
1026
1027 // Determine scope of rewrite.
1028 if (!config.getScope()) {
1029 // Compute scope if none was provided. The scope will remain `nullptr` if
1030 // there is a top-level op among `ops`.
1031 config.setScope(findCommonAncestor(ops));
1032 } else {
1033 // If a scope was provided, make sure that all ops are in scope.
1034#ifndef NDEBUG
1035 bool allOpsInScope = llvm::all_of(Range&: ops, P: [&](Operation *op) {
1036 return static_cast<bool>(config.getScope()->findAncestorOpInRegion(op&: *op));
1037 });
1038 assert(allOpsInScope && "ops must be within the specified scope");
1039#endif // NDEBUG
1040 }
1041
1042#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1043 if (config.getScope() && failed(verify(config.getScope()->getParentOp())))
1044 llvm::report_fatal_error(
1045 "greedy pattern rewriter input IR failed to verify");
1046#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1047
1048 // Start the pattern driver.
1049 llvm::SmallDenseSet<Operation *, 4> surviving;
1050 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
1051 config, ops,
1052 allErased ? &surviving : nullptr);
1053 LogicalResult converged = std::move(driver).simplify(ops, changed);
1054 if (allErased)
1055 *allErased = surviving.empty();
1056 LLVM_DEBUG(if (failed(converged)) {
1057 llvm::dbgs() << "The pattern rewrite did not converge after "
1058 << config.getMaxNumRewrites() << " rewrites";
1059 });
1060 return converged;
1061}
1062

// Provided by KDAB
//
// Privacy Policy
// Improve your Profiling and Debugging skills
// Find out more
//
// source code of mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp