1//===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Dialect/FIRDialect.h"
10#include "flang/Optimizer/Dialect/FIROps.h"
11#include "flang/Optimizer/Dialect/FIRType.h"
12#include "flang/Optimizer/Transforms/Passes.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/Transforms/Passes.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallVector.h"
19#include <optional>
20
21namespace fir {
22#define GEN_PASS_DEF_MEMREFDATAFLOWOPT
23#include "flang/Optimizer/Transforms/Passes.h.inc"
24} // namespace fir
25
26#define DEBUG_TYPE "fir-memref-dataflow-opt"
27
28using namespace mlir;
29
30namespace {
31
32template <typename OpT>
33static std::vector<OpT> getSpecificUsers(mlir::Value v) {
34 std::vector<OpT> ops;
35 for (mlir::Operation *user : v.getUsers())
36 if (auto op = dyn_cast<OpT>(user))
37 ops.push_back(op);
38 return ops;
39}
40
41/// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
42/// and AffineWrite interface
43template <typename ReadOp, typename WriteOp>
44class LoadStoreForwarding {
45public:
46 LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
47
48 // FIXME: This algorithm has a bug. It ignores escaping references between a
49 // store and a load.
50 std::optional<WriteOp> findStoreToForward(ReadOp loadOp,
51 std::vector<WriteOp> &&storeOps) {
52 llvm::SmallVector<WriteOp> candidateSet;
53
54 for (auto storeOp : storeOps)
55 if (domInfo->dominates(storeOp, loadOp))
56 candidateSet.push_back(storeOp);
57
58 if (candidateSet.empty())
59 return {};
60
61 std::optional<WriteOp> nearestStore;
62 for (auto candidate : candidateSet) {
63 auto nearerThan = [&](WriteOp otherStore) {
64 if (candidate == otherStore)
65 return false;
66 bool rv = domInfo->properlyDominates(candidate, otherStore);
67 if (rv) {
68 LLVM_DEBUG(llvm::dbgs()
69 << "candidate " << candidate << " is not the nearest to "
70 << loadOp << " because " << otherStore << " is closer\n");
71 }
72 return rv;
73 };
74 if (!llvm::any_of(candidateSet, nearerThan)) {
75 nearestStore = mlir::cast<WriteOp>(candidate);
76 break;
77 }
78 }
79 if (!nearestStore) {
80 LLVM_DEBUG(
81 llvm::dbgs()
82 << "load " << loadOp << " has " << candidateSet.size()
83 << " store candidates, but this algorithm can't find a best.\n");
84 }
85 return nearestStore;
86 }
87
88 std::optional<ReadOp> findReadForWrite(WriteOp storeOp,
89 std::vector<ReadOp> &&loadOps) {
90 for (auto &loadOp : loadOps) {
91 if (domInfo->dominates(storeOp, loadOp))
92 return loadOp;
93 }
94 return {};
95 }
96
97private:
98 mlir::DominanceInfo *domInfo;
99};
100
101class MemDataFlowOpt : public fir::impl::MemRefDataFlowOptBase<MemDataFlowOpt> {
102public:
103 void runOnOperation() override {
104 mlir::func::FuncOp f = getOperation();
105
106 auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
107 LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
108 f.walk([&](fir::LoadOp loadOp) {
109 auto maybeStore = lsf.findStoreToForward(
110 loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref()));
111 if (maybeStore) {
112 auto storeOp = *maybeStore;
113 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
114 << " erasing load " << loadOp
115 << " with value from " << storeOp << '\n');
116 loadOp.getResult().replaceAllUsesWith(storeOp.getValue());
117 loadOp.erase();
118 }
119 });
120 f.walk([&](fir::AllocaOp alloca) {
121 for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) {
122 if (!lsf.findReadForWrite(
123 storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) {
124 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
125 << " erasing store " << storeOp << '\n');
126 storeOp.erase();
127 }
128 }
129 });
130 }
131};
132} // namespace
133
134std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
135 return std::make_unique<MemDataFlowOpt>();
136}
137

source code of flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp