1 | //===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" |
10 | |
11 | #include "mlir/IR/BuiltinOps.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | #include "mlir/Pass/Pass.h" |
14 | #include "mlir/Transforms/DialectConversion.h" |
15 | |
16 | namespace mlir { |
17 | #define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS |
18 | #include "mlir/Conversion/Passes.h.inc" |
19 | } // namespace mlir |
20 | |
21 | using namespace mlir; |
22 | |
23 | namespace { |
24 | |
25 | /// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types |
26 | /// the same as the input ones. |
27 | /// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A` |
28 | /// represent a noop within the IR, and thus the initial input values can be |
29 | /// propagated. |
30 | /// The same does not hold for 'open' chains of casts, such as |
31 | /// `A -> B -> C`. In this last case there is no cycle among the types and thus |
32 | /// the conversion is incomplete. The same hold for 'closed' chains like |
33 | /// `A -> B -> A`, but with the result of type `B` being used by some non-cast |
34 | /// operations. |
35 | /// Bifurcations (that is when a chain starts in between of another one) are |
36 | /// also taken into considerations, and all the above considerations remain |
37 | /// valid. |
38 | /// Special corner cases such as dead casts or single casts with same input and |
39 | /// output types are also covered. |
40 | struct UnrealizedConversionCastPassthrough |
41 | : public OpRewritePattern<UnrealizedConversionCastOp> { |
42 | using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern; |
43 | |
44 | LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, |
45 | PatternRewriter &rewriter) const override { |
46 | // The nodes that either are not used by any operation or have at least |
47 | // one user that is not an unrealized cast. |
48 | DenseSet<UnrealizedConversionCastOp> exitNodes; |
49 | |
50 | // The nodes whose users are all unrealized casts |
51 | DenseSet<UnrealizedConversionCastOp> intermediateNodes; |
52 | |
53 | // Stack used for the depth-first traversal of the use-def DAG. |
54 | SmallVector<UnrealizedConversionCastOp, 2> visitStack; |
55 | visitStack.push_back(op); |
56 | |
57 | while (!visitStack.empty()) { |
58 | UnrealizedConversionCastOp current = visitStack.pop_back_val(); |
59 | auto users = current->getUsers(); |
60 | bool isLive = false; |
61 | |
62 | for (Operation *user : users) { |
63 | if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) { |
64 | if (other.getInputs() != current.getOutputs()) |
65 | return rewriter.notifyMatchFailure( |
66 | op, "mismatching values propagation" ); |
67 | } else { |
68 | isLive = true; |
69 | } |
70 | |
71 | // Continue traversing the DAG of unrealized casts |
72 | if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) |
73 | visitStack.push_back(other); |
74 | } |
75 | |
76 | // If the cast is live, then we need to check if the results of the last |
77 | // cast have the same type of the root inputs. It this is the case (e.g. |
78 | // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a |
79 | // no-op and the inputs can be forwarded. If it's not (e.g. |
80 | // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete. |
81 | |
82 | bool isCycle = current.getResultTypes() == op.getInputs().getTypes(); |
83 | |
84 | if (isLive && !isCycle) |
85 | return rewriter.notifyMatchFailure(op, |
86 | "live unrealized conversion cast" ); |
87 | |
88 | bool isExitNode = users.empty() || isLive; |
89 | |
90 | if (isExitNode) { |
91 | exitNodes.insert(current); |
92 | } else { |
93 | intermediateNodes.insert(current); |
94 | } |
95 | } |
96 | |
97 | // Replace the sink nodes with the root input values |
98 | for (UnrealizedConversionCastOp exitNode : exitNodes) |
99 | rewriter.replaceOp(exitNode, op.getInputs()); |
100 | |
101 | // Erase all the other casts belonging to the DAG |
102 | for (UnrealizedConversionCastOp castOp : intermediateNodes) |
103 | rewriter.eraseOp(castOp); |
104 | |
105 | return success(); |
106 | } |
107 | }; |
108 | |
109 | /// Pass to simplify and eliminate unrealized conversion casts. |
110 | struct ReconcileUnrealizedCasts |
111 | : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> { |
112 | ReconcileUnrealizedCasts() = default; |
113 | |
114 | void runOnOperation() override { |
115 | RewritePatternSet patterns(&getContext()); |
116 | populateReconcileUnrealizedCastsPatterns(patterns); |
117 | ConversionTarget target(getContext()); |
118 | target.addIllegalOp<UnrealizedConversionCastOp>(); |
119 | if (failed(applyPartialConversion(getOperation(), target, |
120 | std::move(patterns)))) |
121 | signalPassFailure(); |
122 | } |
123 | }; |
124 | |
125 | } // namespace |
126 | |
127 | void mlir::populateReconcileUnrealizedCastsPatterns( |
128 | RewritePatternSet &patterns) { |
129 | patterns.add<UnrealizedConversionCastPassthrough>(arg: patterns.getContext()); |
130 | } |
131 | |
132 | std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() { |
133 | return std::make_unique<ReconcileUnrealizedCasts>(); |
134 | } |
135 | |