1//===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
10
11#include "mlir/IR/BuiltinOps.h"
12#include "mlir/IR/PatternMatch.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/DialectConversion.h"
15
16namespace mlir {
17#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
18#include "mlir/Conversion/Passes.h.inc"
19} // namespace mlir
20
21using namespace mlir;
22
23namespace {
24
25/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
26/// the same as the input ones.
27/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
28/// represent a noop within the IR, and thus the initial input values can be
29/// propagated.
30/// The same does not hold for 'open' chains of casts, such as
31/// `A -> B -> C`. In this last case there is no cycle among the types and thus
32/// the conversion is incomplete. The same hold for 'closed' chains like
33/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
34/// operations.
35/// Bifurcations (that is when a chain starts in between of another one) are
36/// also taken into considerations, and all the above considerations remain
37/// valid.
38/// Special corner cases such as dead casts or single casts with same input and
39/// output types are also covered.
40struct UnrealizedConversionCastPassthrough
41 : public OpRewritePattern<UnrealizedConversionCastOp> {
42 using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
43
44 LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
45 PatternRewriter &rewriter) const override {
46 // The nodes that either are not used by any operation or have at least
47 // one user that is not an unrealized cast.
48 DenseSet<UnrealizedConversionCastOp> exitNodes;
49
50 // The nodes whose users are all unrealized casts
51 DenseSet<UnrealizedConversionCastOp> intermediateNodes;
52
53 // Stack used for the depth-first traversal of the use-def DAG.
54 SmallVector<UnrealizedConversionCastOp, 2> visitStack;
55 visitStack.push_back(op);
56
57 while (!visitStack.empty()) {
58 UnrealizedConversionCastOp current = visitStack.pop_back_val();
59 auto users = current->getUsers();
60 bool isLive = false;
61
62 for (Operation *user : users) {
63 if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
64 if (other.getInputs() != current.getOutputs())
65 return rewriter.notifyMatchFailure(
66 op, "mismatching values propagation");
67 } else {
68 isLive = true;
69 }
70
71 // Continue traversing the DAG of unrealized casts
72 if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
73 visitStack.push_back(other);
74 }
75
76 // If the cast is live, then we need to check if the results of the last
77 // cast have the same type of the root inputs. It this is the case (e.g.
78 // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
79 // no-op and the inputs can be forwarded. If it's not (e.g.
80 // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
81
82 bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
83
84 if (isLive && !isCycle)
85 return rewriter.notifyMatchFailure(op,
86 "live unrealized conversion cast");
87
88 bool isExitNode = users.empty() || isLive;
89
90 if (isExitNode) {
91 exitNodes.insert(current);
92 } else {
93 intermediateNodes.insert(current);
94 }
95 }
96
97 // Replace the sink nodes with the root input values
98 for (UnrealizedConversionCastOp exitNode : exitNodes)
99 rewriter.replaceOp(exitNode, op.getInputs());
100
101 // Erase all the other casts belonging to the DAG
102 for (UnrealizedConversionCastOp castOp : intermediateNodes)
103 rewriter.eraseOp(castOp);
104
105 return success();
106 }
107};
108
109/// Pass to simplify and eliminate unrealized conversion casts.
110struct ReconcileUnrealizedCasts
111 : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
112 ReconcileUnrealizedCasts() = default;
113
114 void runOnOperation() override {
115 RewritePatternSet patterns(&getContext());
116 populateReconcileUnrealizedCastsPatterns(patterns);
117 ConversionTarget target(getContext());
118 target.addIllegalOp<UnrealizedConversionCastOp>();
119 if (failed(applyPartialConversion(getOperation(), target,
120 std::move(patterns))))
121 signalPassFailure();
122 }
123};
124
125} // namespace
126
127void mlir::populateReconcileUnrealizedCastsPatterns(
128 RewritePatternSet &patterns) {
129 patterns.add<UnrealizedConversionCastPassthrough>(arg: patterns.getContext());
130}
131
132std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
133 return std::make_unique<ReconcileUnrealizedCasts>();
134}
135

// source code of mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp