1//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass converts operations into their canonical forms by
10// folding constants, applying operation identity transformations etc.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/Passes.h"
15
16#include "mlir/Dialect/UB/IR/UBOps.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_CANONICALIZER
22#include "mlir/Transforms/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28/// Canonicalize operations in nested regions.
29struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
30 Canonicalizer() = default;
31 Canonicalizer(const GreedyRewriteConfig &config,
32 ArrayRef<std::string> disabledPatterns,
33 ArrayRef<std::string> enabledPatterns)
34 : config(config) {
35 this->topDownProcessingEnabled = config.getUseTopDownTraversal();
36 this->regionSimplifyLevel = config.getRegionSimplificationLevel();
37 this->maxIterations = config.getMaxIterations();
38 this->maxNumRewrites = config.getMaxNumRewrites();
39 this->disabledPatterns = disabledPatterns;
40 this->enabledPatterns = enabledPatterns;
41 }
42
43 /// Initialize the canonicalizer by building the set of patterns used during
44 /// execution.
45 LogicalResult initialize(MLIRContext *context) override {
46 // Set the config from possible pass options set in the meantime.
47 config.setUseTopDownTraversal(topDownProcessingEnabled);
48 config.setRegionSimplificationLevel(regionSimplifyLevel);
49 config.setMaxIterations(maxIterations);
50 config.setMaxNumRewrites(maxNumRewrites);
51
52 RewritePatternSet owningPatterns(context);
53 for (auto *dialect : context->getLoadedDialects())
54 dialect->getCanonicalizationPatterns(owningPatterns);
55 for (RegisteredOperationName op : context->getRegisteredOperations())
56 op.getCanonicalizationPatterns(owningPatterns, context);
57
58 patterns = std::make_shared<FrozenRewritePatternSet>(
59 std::move(owningPatterns), disabledPatterns, enabledPatterns);
60 return success();
61 }
62 void runOnOperation() override {
63 LogicalResult converged =
64 applyPatternsGreedily(getOperation(), *patterns, config);
65 // Canonicalization is best-effort. Non-convergence is not a pass failure.
66 if (testConvergence && failed(converged))
67 signalPassFailure();
68 }
69 GreedyRewriteConfig config;
70 std::shared_ptr<const FrozenRewritePatternSet> patterns;
71};
72} // namespace
73
74/// Create a Canonicalizer pass.
75std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
76 return std::make_unique<Canonicalizer>();
77}
78
79/// Creates an instance of the Canonicalizer pass with the specified config.
80std::unique_ptr<Pass>
81mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
82 ArrayRef<std::string> disabledPatterns,
83 ArrayRef<std::string> enabledPatterns) {
84 return std::make_unique<Canonicalizer>(args: config, args&: disabledPatterns,
85 args&: enabledPatterns);
86}
87

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Transforms/Canonicalizer.cpp