1//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Shape/Transforms/Passes.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Shape/IR/Shape.h"
13#include "mlir/Transforms/DialectConversion.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16namespace mlir {
17#define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS
18#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
19} // namespace mlir
20
21using namespace mlir;
22
23namespace {
24/// Removal patterns.
25class RemoveCstrBroadcastableOp
26 : public OpRewritePattern<shape::CstrBroadcastableOp> {
27public:
28 using OpRewritePattern::OpRewritePattern;
29
30 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
31 PatternRewriter &rewriter) const override {
32 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
33 return success();
34 }
35};
36
37class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> {
38public:
39 using OpRewritePattern::OpRewritePattern;
40
41 LogicalResult matchAndRewrite(shape::CstrEqOp op,
42 PatternRewriter &rewriter) const override {
43 rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
44 return success();
45 }
46};
47
48/// Removal pass.
49class RemoveShapeConstraintsPass
50 : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
51
52 void runOnOperation() override {
53 MLIRContext &ctx = getContext();
54
55 RewritePatternSet patterns(&ctx);
56 populateRemoveShapeConstraintsPatterns(patterns);
57
58 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
59 }
60};
61
62} // namespace
63
64void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) {
65 patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
66 arg: patterns.getContext());
67}
68
69std::unique_ptr<OperationPass<func::FuncOp>>
70mlir::createRemoveShapeConstraintsPass() {
71 return std::make_unique<RemoveShapeConstraintsPass>();
72}
73

source code of mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp