1 | //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This transformation pass converts operations into their canonical forms by |
10 | // folding constants, applying operation identity transformations etc. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Pass/Pass.h" |
17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
18 | |
19 | namespace mlir { |
20 | #define GEN_PASS_DEF_CANONICALIZER |
21 | #include "mlir/Transforms/Passes.h.inc" |
22 | } // namespace mlir |
23 | |
24 | using namespace mlir; |
25 | |
26 | namespace { |
27 | /// Canonicalize operations in nested regions. |
28 | struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> { |
29 | Canonicalizer() = default; |
30 | Canonicalizer(const GreedyRewriteConfig &config, |
31 | ArrayRef<std::string> disabledPatterns, |
32 | ArrayRef<std::string> enabledPatterns) |
33 | : config(config) { |
34 | this->topDownProcessingEnabled = config.useTopDownTraversal; |
35 | this->enableRegionSimplification = config.enableRegionSimplification; |
36 | this->maxIterations = config.maxIterations; |
37 | this->maxNumRewrites = config.maxNumRewrites; |
38 | this->disabledPatterns = disabledPatterns; |
39 | this->enabledPatterns = enabledPatterns; |
40 | } |
41 | |
42 | /// Initialize the canonicalizer by building the set of patterns used during |
43 | /// execution. |
44 | LogicalResult initialize(MLIRContext *context) override { |
45 | // Set the config from possible pass options set in the meantime. |
46 | config.useTopDownTraversal = topDownProcessingEnabled; |
47 | config.enableRegionSimplification = enableRegionSimplification; |
48 | config.maxIterations = maxIterations; |
49 | config.maxNumRewrites = maxNumRewrites; |
50 | |
51 | RewritePatternSet owningPatterns(context); |
52 | for (auto *dialect : context->getLoadedDialects()) |
53 | dialect->getCanonicalizationPatterns(owningPatterns); |
54 | for (RegisteredOperationName op : context->getRegisteredOperations()) |
55 | op.getCanonicalizationPatterns(owningPatterns, context); |
56 | |
57 | patterns = std::make_shared<FrozenRewritePatternSet>( |
58 | std::move(owningPatterns), disabledPatterns, enabledPatterns); |
59 | return success(); |
60 | } |
61 | void runOnOperation() override { |
62 | LogicalResult converged = |
63 | applyPatternsAndFoldGreedily(getOperation(), *patterns, config); |
64 | // Canonicalization is best-effort. Non-convergence is not a pass failure. |
65 | if (testConvergence && failed(converged)) |
66 | signalPassFailure(); |
67 | } |
68 | GreedyRewriteConfig config; |
69 | std::shared_ptr<const FrozenRewritePatternSet> patterns; |
70 | }; |
71 | } // namespace |
72 | |
73 | /// Create a Canonicalizer pass. |
74 | std::unique_ptr<Pass> mlir::createCanonicalizerPass() { |
75 | return std::make_unique<Canonicalizer>(); |
76 | } |
77 | |
78 | /// Creates an instance of the Canonicalizer pass with the specified config. |
79 | std::unique_ptr<Pass> |
80 | mlir::createCanonicalizerPass(const GreedyRewriteConfig &config, |
81 | ArrayRef<std::string> disabledPatterns, |
82 | ArrayRef<std::string> enabledPatterns) { |
83 | return std::make_unique<Canonicalizer>(args: config, args&: disabledPatterns, |
84 | args&: enabledPatterns); |
85 | } |
86 | |