| 1 | //===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Implementation of test passes exercising dense forward data flow analysis.
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "TestDenseDataFlowAnalysis.h" |
| 14 | #include "TestDialect.h" |
| 15 | #include "TestOps.h" |
| 16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| 17 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
| 18 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
| 19 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 20 | #include "mlir/Pass/Pass.h" |
| 21 | #include "mlir/Support/LLVM.h" |
| 22 | #include "llvm/ADT/TypeSwitch.h" |
| 23 | #include <optional> |
| 24 | |
| 25 | using namespace mlir; |
| 26 | using namespace mlir::dataflow; |
| 27 | using namespace mlir::dataflow::test; |
| 28 | |
| 29 | namespace { |
| 30 | |
| 31 | /// This lattice represents, for a given memory resource, the potential last |
| 32 | /// operations that modified the resource. |
| 33 | class LastModification : public AbstractDenseLattice, public AccessLatticeBase { |
| 34 | public: |
| 35 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) |
| 36 | |
| 37 | using AbstractDenseLattice::AbstractDenseLattice; |
| 38 | |
| 39 | /// Join the last modifications. |
| 40 | ChangeResult join(const AbstractDenseLattice &lattice) override { |
| 41 | return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>( |
| 42 | static_cast<const LastModification &>(lattice))); |
| 43 | } |
| 44 | |
| 45 | void print(raw_ostream &os) const override { |
| 46 | return AccessLatticeBase::print(os); |
| 47 | } |
| 48 | }; |
| 49 | |
| 50 | class LastModifiedAnalysis |
| 51 | : public DenseForwardDataFlowAnalysis<LastModification> { |
| 52 | public: |
| 53 | explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites) |
| 54 | : DenseForwardDataFlowAnalysis(solver), |
| 55 | assumeFuncWrites(assumeFuncWrites) {} |
| 56 | |
| 57 | /// Visit an operation. If the operation has no memory effects, then the state |
| 58 | /// is propagated with no change. If the operation allocates a resource, then |
| 59 | /// its reaching definitions is set to empty. If the operation writes to a |
| 60 | /// resource, then its reaching definition is set to the written value. |
| 61 | LogicalResult visitOperation(Operation *op, const LastModification &before, |
| 62 | LastModification *after) override; |
| 63 | |
| 64 | void visitCallControlFlowTransfer(CallOpInterface call, |
| 65 | CallControlFlowAction action, |
| 66 | const LastModification &before, |
| 67 | LastModification *after) override; |
| 68 | |
| 69 | void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
| 70 | std::optional<unsigned> regionFrom, |
| 71 | std::optional<unsigned> regionTo, |
| 72 | const LastModification &before, |
| 73 | LastModification *after) override; |
| 74 | |
| 75 | /// Visit an operation. If this analysis can confirm that lattice content |
| 76 | /// of lattice anchors around operation are necessarily identical, join |
| 77 | /// them into the same equivalent class. |
| 78 | void buildOperationEquivalentLatticeAnchor(Operation *op) override; |
| 79 | |
| 80 | /// At an entry point, the last modifications of all memory resources are |
| 81 | /// unknown. |
| 82 | void setToEntryState(LastModification *lattice) override { |
| 83 | propagateIfChanged(state: lattice, changed: lattice->reset()); |
| 84 | } |
| 85 | |
| 86 | private: |
| 87 | const bool assumeFuncWrites; |
| 88 | }; |
| 89 | } // end anonymous namespace |
| 90 | |
| 91 | LogicalResult LastModifiedAnalysis::visitOperation( |
| 92 | Operation *op, const LastModification &before, LastModification *after) { |
| 93 | auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
| 94 | // If we can't reason about the memory effects, then conservatively assume we |
| 95 | // can't deduce anything about the last modifications. |
| 96 | if (!memory) { |
| 97 | setToEntryState(after); |
| 98 | return success(); |
| 99 | } |
| 100 | |
| 101 | SmallVector<MemoryEffects::EffectInstance> effects; |
| 102 | memory.getEffects(effects); |
| 103 | |
| 104 | // First, check if all underlying values are already known. Otherwise, avoid |
| 105 | // propagating and stay in the "undefined" state to avoid incorrectly |
| 106 | // propagating values that may be overwritten later on as that could be |
| 107 | // problematic for convergence based on monotonicity of lattice updates. |
| 108 | SmallVector<Value> underlyingValues; |
| 109 | underlyingValues.reserve(N: effects.size()); |
| 110 | for (const auto &effect : effects) { |
| 111 | Value value = effect.getValue(); |
| 112 | |
| 113 | // If we see an effect on anything other than a value, assume we can't |
| 114 | // deduce anything about the last modifications. |
| 115 | if (!value) { |
| 116 | setToEntryState(after); |
| 117 | return success(); |
| 118 | } |
| 119 | |
| 120 | // If we cannot find the underlying value, we shouldn't just propagate the |
| 121 | // effects through, return the pessimistic state. |
| 122 | std::optional<Value> underlyingValue = |
| 123 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
| 124 | value, getUnderlyingValueFn: [&](Value value) { |
| 125 | return getOrCreateFor<UnderlyingValueLattice>( |
| 126 | dependent: getProgramPointAfter(op), anchor: value); |
| 127 | }); |
| 128 | |
| 129 | // If the underlying value is not yet known, don't propagate yet. |
| 130 | if (!underlyingValue) |
| 131 | return success(); |
| 132 | |
| 133 | underlyingValues.push_back(Elt: *underlyingValue); |
| 134 | } |
| 135 | |
| 136 | // Update the state when all underlying values are known. |
| 137 | ChangeResult result = after->join(lattice: before); |
| 138 | for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) { |
| 139 | // If the underlying value is known to be unknown, set to fixpoint state. |
| 140 | if (!value) { |
| 141 | setToEntryState(after); |
| 142 | return success(); |
| 143 | } |
| 144 | |
| 145 | // Nothing to do for reads. |
| 146 | if (isa<MemoryEffects::Read>(Val: effect.getEffect())) |
| 147 | continue; |
| 148 | |
| 149 | result |= after->set(value, op); |
| 150 | } |
| 151 | propagateIfChanged(state: after, changed: result); |
| 152 | return success(); |
| 153 | } |
| 154 | |
| 155 | void LastModifiedAnalysis::buildOperationEquivalentLatticeAnchor( |
| 156 | Operation *op) { |
| 157 | if (isMemoryEffectFree(op)) { |
| 158 | unionLatticeAnchors<LastModification>(anchor: getProgramPointBefore(op), |
| 159 | other: getProgramPointAfter(op)); |
| 160 | } |
| 161 | } |
| 162 | |
| 163 | void LastModifiedAnalysis::visitCallControlFlowTransfer( |
| 164 | CallOpInterface call, CallControlFlowAction action, |
| 165 | const LastModification &before, LastModification *after) { |
| 166 | if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) { |
| 167 | SmallVector<Value> underlyingValues; |
| 168 | underlyingValues.reserve(N: call->getNumOperands()); |
| 169 | for (Value operand : call.getArgOperands()) { |
| 170 | std::optional<Value> underlyingValue = |
| 171 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
| 172 | operand, [&](Value value) { |
| 173 | return getOrCreateFor<UnderlyingValueLattice>( |
| 174 | getProgramPointAfter(call.getOperation()), value); |
| 175 | }); |
| 176 | if (!underlyingValue) |
| 177 | return; |
| 178 | underlyingValues.push_back(*underlyingValue); |
| 179 | } |
| 180 | |
| 181 | ChangeResult result = after->join(lattice: before); |
| 182 | for (Value operand : underlyingValues) |
| 183 | result |= after->set(value: operand, op: call); |
| 184 | return propagateIfChanged(state: after, changed: result); |
| 185 | } |
| 186 | auto testCallAndStore = |
| 187 | dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
| 188 | if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
| 189 | testCallAndStore.getStoreBeforeCall()) || |
| 190 | (action == CallControlFlowAction::ExitCallee && |
| 191 | !testCallAndStore.getStoreBeforeCall()))) { |
| 192 | (void)visitOperation(op: call, before, after); |
| 193 | return; |
| 194 | } |
| 195 | AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer( |
| 196 | call: call, action, before, after); |
| 197 | } |
| 198 | |
| 199 | void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer( |
| 200 | RegionBranchOpInterface branch, std::optional<unsigned> regionFrom, |
| 201 | std::optional<unsigned> regionTo, const LastModification &before, |
| 202 | LastModification *after) { |
| 203 | auto defaultHandling = [&]() { |
| 204 | AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( |
| 205 | branch: branch, regionFrom, regionTo, before, after); |
| 206 | }; |
| 207 | TypeSwitch<Operation *>(branch.getOperation()) |
| 208 | .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>( |
| 209 | [=](auto storeWithRegion) { |
| 210 | if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) || |
| 211 | (!regionFrom && storeWithRegion.getStoreBeforeRegion())) |
| 212 | (void)visitOperation(branch, before, after); |
| 213 | defaultHandling(); |
| 214 | }) |
| 215 | .Default([=](auto) { defaultHandling(); }); |
| 216 | } |
| 217 | |
| 218 | namespace { |
| 219 | struct TestLastModifiedPass |
| 220 | : public PassWrapper<TestLastModifiedPass, OperationPass<>> { |
| 221 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) |
| 222 | |
| 223 | TestLastModifiedPass() = default; |
| 224 | TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) { |
| 225 | interprocedural = other.interprocedural; |
| 226 | assumeFuncWrites = other.assumeFuncWrites; |
| 227 | } |
| 228 | |
| 229 | StringRef getArgument() const override { return "test-last-modified" ; } |
| 230 | |
| 231 | Option<bool> interprocedural{ |
| 232 | *this, "interprocedural" , llvm::cl::init(Val: true), |
| 233 | llvm::cl::desc("perform interprocedural analysis" )}; |
| 234 | Option<bool> assumeFuncWrites{ |
| 235 | *this, "assume-func-writes" , llvm::cl::init(Val: false), |
| 236 | llvm::cl::desc( |
| 237 | "assume external functions have write effect on all arguments" )}; |
| 238 | |
| 239 | void runOnOperation() override { |
| 240 | Operation *op = getOperation(); |
| 241 | |
| 242 | DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); |
| 243 | solver.load<DeadCodeAnalysis>(); |
| 244 | solver.load<SparseConstantPropagation>(); |
| 245 | solver.load<LastModifiedAnalysis>(args&: assumeFuncWrites); |
| 246 | solver.load<UnderlyingValueAnalysis>(); |
| 247 | if (failed(Result: solver.initializeAndRun(top: op))) |
| 248 | return signalPassFailure(); |
| 249 | |
| 250 | raw_ostream &os = llvm::errs(); |
| 251 | |
| 252 | // Note that if the underlying value could not be computed or is unknown, we |
| 253 | // conservatively treat the result also unknown. |
| 254 | op->walk(callback: [&](Operation *op) { |
| 255 | auto tag = op->getAttrOfType<StringAttr>("tag" ); |
| 256 | if (!tag) |
| 257 | return; |
| 258 | os << "test_tag: " << tag.getValue() << ":\n" ; |
| 259 | const LastModification *lastMods = |
| 260 | solver.lookupState<LastModification>(anchor: solver.getProgramPointAfter(op)); |
| 261 | assert(lastMods && "expected a dense lattice" ); |
| 262 | for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) { |
| 263 | os << " operand #" << index << "\n" ; |
| 264 | std::optional<Value> underlyingValue = |
| 265 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
| 266 | value: operand, getUnderlyingValueFn: [&](Value value) { |
| 267 | return solver.lookupState<UnderlyingValueLattice>(anchor: value); |
| 268 | }); |
| 269 | if (!underlyingValue) { |
| 270 | os << " - <unknown>\n" ; |
| 271 | continue; |
| 272 | } |
| 273 | Value value = *underlyingValue; |
| 274 | assert(value && "expected an underlying value" ); |
| 275 | if (const AdjacentAccess *lastMod = |
| 276 | lastMods->getAdjacentAccess(value)) { |
| 277 | if (!lastMod->isKnown()) { |
| 278 | os << " - <unknown>\n" ; |
| 279 | } else { |
| 280 | for (Operation *lastModifier : lastMod->get()) { |
| 281 | if (auto tagName = |
| 282 | lastModifier->getAttrOfType<StringAttr>("tag_name" )) { |
| 283 | os << " - " << tagName.getValue() << "\n" ; |
| 284 | } else { |
| 285 | os << " - " << lastModifier->getName() << "\n" ; |
| 286 | } |
| 287 | } |
| 288 | } |
| 289 | } else { |
| 290 | os << " - <unknown>\n" ; |
| 291 | } |
| 292 | } |
| 293 | }); |
| 294 | } |
| 295 | }; |
| 296 | } // end anonymous namespace |
| 297 | |
| 298 | namespace mlir { |
| 299 | namespace test { |
| 300 | void registerTestLastModifiedPass() { |
| 301 | PassRegistration<TestLastModifiedPass>(); |
| 302 | } |
| 303 | } // end namespace test |
| 304 | } // end namespace mlir |
| 305 | |