1 | //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
10 | #include "flang/Optimizer/Dialect/FIROps.h" |
11 | #include "flang/Optimizer/Dialect/FIRType.h" |
12 | #include "flang/Optimizer/Transforms/Passes.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/IR/Dominance.h" |
15 | #include "mlir/IR/Operation.h" |
16 | #include "mlir/Transforms/Passes.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | #include "llvm/ADT/SmallVector.h" |
19 | #include <optional> |
20 | |
21 | namespace fir { |
22 | #define GEN_PASS_DEF_MEMREFDATAFLOWOPT |
23 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
24 | } // namespace fir |
25 | |
26 | #define DEBUG_TYPE "fir-memref-dataflow-opt" |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | |
32 | template <typename OpT> |
33 | static std::vector<OpT> getSpecificUsers(mlir::Value v) { |
34 | std::vector<OpT> ops; |
35 | for (mlir::Operation *user : v.getUsers()) |
36 | if (auto op = dyn_cast<OpT>(user)) |
37 | ops.push_back(op); |
38 | return ops; |
39 | } |
40 | |
41 | /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead |
42 | /// and AffineWrite interface |
43 | template <typename ReadOp, typename WriteOp> |
44 | class LoadStoreForwarding { |
45 | public: |
46 | LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {} |
47 | |
48 | // FIXME: This algorithm has a bug. It ignores escaping references between a |
49 | // store and a load. |
50 | std::optional<WriteOp> findStoreToForward(ReadOp loadOp, |
51 | std::vector<WriteOp> &&storeOps) { |
52 | llvm::SmallVector<WriteOp> candidateSet; |
53 | |
54 | for (auto storeOp : storeOps) |
55 | if (domInfo->dominates(storeOp, loadOp)) |
56 | candidateSet.push_back(storeOp); |
57 | |
58 | if (candidateSet.empty()) |
59 | return {}; |
60 | |
61 | std::optional<WriteOp> nearestStore; |
62 | for (auto candidate : candidateSet) { |
63 | auto nearerThan = [&](WriteOp otherStore) { |
64 | if (candidate == otherStore) |
65 | return false; |
66 | bool rv = domInfo->properlyDominates(candidate, otherStore); |
67 | if (rv) { |
68 | LLVM_DEBUG(llvm::dbgs() |
69 | << "candidate " << candidate << " is not the nearest to " |
70 | << loadOp << " because " << otherStore << " is closer\n" ); |
71 | } |
72 | return rv; |
73 | }; |
74 | if (!llvm::any_of(candidateSet, nearerThan)) { |
75 | nearestStore = mlir::cast<WriteOp>(candidate); |
76 | break; |
77 | } |
78 | } |
79 | if (!nearestStore) { |
80 | LLVM_DEBUG( |
81 | llvm::dbgs() |
82 | << "load " << loadOp << " has " << candidateSet.size() |
83 | << " store candidates, but this algorithm can't find a best.\n" ); |
84 | } |
85 | return nearestStore; |
86 | } |
87 | |
88 | std::optional<ReadOp> findReadForWrite(WriteOp storeOp, |
89 | std::vector<ReadOp> &&loadOps) { |
90 | for (auto &loadOp : loadOps) { |
91 | if (domInfo->dominates(storeOp, loadOp)) |
92 | return loadOp; |
93 | } |
94 | return {}; |
95 | } |
96 | |
97 | private: |
98 | mlir::DominanceInfo *domInfo; |
99 | }; |
100 | |
101 | class MemDataFlowOpt : public fir::impl::MemRefDataFlowOptBase<MemDataFlowOpt> { |
102 | public: |
103 | void runOnOperation() override { |
104 | mlir::func::FuncOp f = getOperation(); |
105 | |
106 | auto *domInfo = &getAnalysis<mlir::DominanceInfo>(); |
107 | LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo); |
108 | f.walk([&](fir::LoadOp loadOp) { |
109 | auto maybeStore = lsf.findStoreToForward( |
110 | loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref())); |
111 | if (maybeStore) { |
112 | auto storeOp = *maybeStore; |
113 | LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName() |
114 | << " erasing load " << loadOp |
115 | << " with value from " << storeOp << '\n'); |
116 | loadOp.getResult().replaceAllUsesWith(storeOp.getValue()); |
117 | loadOp.erase(); |
118 | } |
119 | }); |
120 | f.walk([&](fir::AllocaOp alloca) { |
121 | for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) { |
122 | if (!lsf.findReadForWrite( |
123 | storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) { |
124 | LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName() |
125 | << " erasing store " << storeOp << '\n'); |
126 | storeOp.erase(); |
127 | } |
128 | } |
129 | }); |
130 | } |
131 | }; |
132 | } // namespace |
133 | |
134 | std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() { |
135 | return std::make_unique<MemDataFlowOpt>(); |
136 | } |
137 | |