1 | //===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Shape/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/Shape/IR/Shape.h" |
13 | #include "mlir/Transforms/DialectConversion.h" |
14 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
15 | |
16 | namespace mlir { |
17 | #define GEN_PASS_DEF_REMOVESHAPECONSTRAINTS |
18 | #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" |
19 | } // namespace mlir |
20 | |
21 | using namespace mlir; |
22 | |
23 | namespace { |
24 | /// Removal patterns. |
25 | class RemoveCstrBroadcastableOp |
26 | : public OpRewritePattern<shape::CstrBroadcastableOp> { |
27 | public: |
28 | using OpRewritePattern::OpRewritePattern; |
29 | |
30 | LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, |
31 | PatternRewriter &rewriter) const override { |
32 | rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); |
33 | return success(); |
34 | } |
35 | }; |
36 | |
37 | class RemoveCstrEqOp : public OpRewritePattern<shape::CstrEqOp> { |
38 | public: |
39 | using OpRewritePattern::OpRewritePattern; |
40 | |
41 | LogicalResult matchAndRewrite(shape::CstrEqOp op, |
42 | PatternRewriter &rewriter) const override { |
43 | rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true); |
44 | return success(); |
45 | } |
46 | }; |
47 | |
48 | /// Removal pass. |
49 | class RemoveShapeConstraintsPass |
50 | : public impl::RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> { |
51 | |
52 | void runOnOperation() override { |
53 | MLIRContext &ctx = getContext(); |
54 | |
55 | RewritePatternSet patterns(&ctx); |
56 | populateRemoveShapeConstraintsPatterns(patterns); |
57 | |
58 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
59 | } |
60 | }; |
61 | |
62 | } // namespace |
63 | |
64 | void mlir::populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns) { |
65 | patterns.add<RemoveCstrBroadcastableOp, RemoveCstrEqOp>( |
66 | arg: patterns.getContext()); |
67 | } |
68 | |
69 | std::unique_ptr<OperationPass<func::FuncOp>> |
70 | mlir::createRemoveShapeConstraintsPass() { |
71 | return std::make_unique<RemoveShapeConstraintsPass>(); |
72 | } |
73 |