1 | //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This transformation pass converts operations into their canonical forms by |
10 | // folding constants, applying operation identity transformations etc. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/UB/IR/UBOps.h" |
17 | #include "mlir/Pass/Pass.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | |
20 | namespace mlir { |
21 | #define GEN_PASS_DEF_CANONICALIZER |
22 | #include "mlir/Transforms/Passes.h.inc" |
23 | } // namespace mlir |
24 | |
25 | using namespace mlir; |
26 | |
27 | namespace { |
28 | /// Canonicalize operations in nested regions. |
29 | struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> { |
30 | Canonicalizer() = default; |
31 | Canonicalizer(const GreedyRewriteConfig &config, |
32 | ArrayRef<std::string> disabledPatterns, |
33 | ArrayRef<std::string> enabledPatterns) |
34 | : config(config) { |
35 | this->topDownProcessingEnabled = config.getUseTopDownTraversal(); |
36 | this->regionSimplifyLevel = config.getRegionSimplificationLevel(); |
37 | this->maxIterations = config.getMaxIterations(); |
38 | this->maxNumRewrites = config.getMaxNumRewrites(); |
39 | this->disabledPatterns = disabledPatterns; |
40 | this->enabledPatterns = enabledPatterns; |
41 | } |
42 | |
43 | /// Initialize the canonicalizer by building the set of patterns used during |
44 | /// execution. |
45 | LogicalResult initialize(MLIRContext *context) override { |
46 | // Set the config from possible pass options set in the meantime. |
47 | config.setUseTopDownTraversal(topDownProcessingEnabled); |
48 | config.setRegionSimplificationLevel(regionSimplifyLevel); |
49 | config.setMaxIterations(maxIterations); |
50 | config.setMaxNumRewrites(maxNumRewrites); |
51 | |
52 | RewritePatternSet owningPatterns(context); |
53 | for (auto *dialect : context->getLoadedDialects()) |
54 | dialect->getCanonicalizationPatterns(owningPatterns); |
55 | for (RegisteredOperationName op : context->getRegisteredOperations()) |
56 | op.getCanonicalizationPatterns(owningPatterns, context); |
57 | |
58 | patterns = std::make_shared<FrozenRewritePatternSet>( |
59 | std::move(owningPatterns), disabledPatterns, enabledPatterns); |
60 | return success(); |
61 | } |
62 | void runOnOperation() override { |
63 | LogicalResult converged = |
64 | applyPatternsGreedily(getOperation(), *patterns, config); |
65 | // Canonicalization is best-effort. Non-convergence is not a pass failure. |
66 | if (testConvergence && failed(converged)) |
67 | signalPassFailure(); |
68 | } |
69 | GreedyRewriteConfig config; |
70 | std::shared_ptr<const FrozenRewritePatternSet> patterns; |
71 | }; |
72 | } // namespace |
73 | |
74 | /// Create a Canonicalizer pass. |
75 | std::unique_ptr<Pass> mlir::createCanonicalizerPass() { |
76 | return std::make_unique<Canonicalizer>(); |
77 | } |
78 | |
79 | /// Creates an instance of the Canonicalizer pass with the specified config. |
80 | std::unique_ptr<Pass> |
81 | mlir::createCanonicalizerPass(const GreedyRewriteConfig &config, |
82 | ArrayRef<std::string> disabledPatterns, |
83 | ArrayRef<std::string> enabledPatterns) { |
84 | return std::make_unique<Canonicalizer>(args: config, args&: disabledPatterns, |
85 | args&: enabledPatterns); |
86 | } |
87 | |