| 1 | //===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Test pass for backward dense dataflow analysis. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "TestDenseDataFlowAnalysis.h" |
| 14 | #include "TestDialect.h" |
| 15 | #include "TestOps.h" |
| 16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| 17 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
| 18 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
| 19 | #include "mlir/Analysis/DataFlowFramework.h" |
| 20 | #include "mlir/IR/Builders.h" |
| 21 | #include "mlir/IR/SymbolTable.h" |
| 22 | #include "mlir/Interfaces/CallInterfaces.h" |
| 23 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| 24 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 25 | #include "mlir/Pass/Pass.h" |
| 26 | #include "mlir/Support/TypeID.h" |
| 27 | #include "llvm/Support/raw_ostream.h" |
| 28 | |
| 29 | using namespace mlir; |
| 30 | using namespace mlir::dataflow; |
| 31 | using namespace mlir::dataflow::test; |
| 32 | |
| 33 | namespace { |
| 34 | |
| 35 | class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { |
| 36 | public: |
| 37 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) |
| 38 | |
| 39 | using dataflow::AbstractDenseLattice::AbstractDenseLattice; |
| 40 | |
| 41 | ChangeResult meet(const AbstractDenseLattice &lattice) override { |
| 42 | return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>( |
| 43 | static_cast<const NextAccess &>(lattice))); |
| 44 | } |
| 45 | |
| 46 | void print(raw_ostream &os) const override { |
| 47 | return AccessLatticeBase::print(os); |
| 48 | } |
| 49 | }; |
| 50 | |
| 51 | class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> { |
| 52 | public: |
| 53 | NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, |
| 54 | bool assumeFuncReads = false) |
| 55 | : DenseBackwardDataFlowAnalysis(solver, symbolTable), |
| 56 | assumeFuncReads(assumeFuncReads) {} |
| 57 | |
| 58 | LogicalResult visitOperation(Operation *op, const NextAccess &after, |
| 59 | NextAccess *before) override; |
| 60 | |
| 61 | void visitCallControlFlowTransfer(CallOpInterface call, |
| 62 | CallControlFlowAction action, |
| 63 | const NextAccess &after, |
| 64 | NextAccess *before) override; |
| 65 | |
| 66 | void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
| 67 | RegionBranchPoint regionFrom, |
| 68 | RegionBranchPoint regionTo, |
| 69 | const NextAccess &after, |
| 70 | NextAccess *before) override; |
| 71 | |
| 72 | // TODO: this isn't ideal for the analysis. When there is no next access, it |
| 73 | // means "we don't know what the next access is" rather than "there is no next |
| 74 | // access". But it's unclear how to differentiate the two cases... |
| 75 | void setToExitState(NextAccess *lattice) override { |
| 76 | propagateIfChanged(state: lattice, changed: lattice->setKnownToUnknown()); |
| 77 | } |
| 78 | |
| 79 | /// Visit an operation. If this analysis can confirm that lattice content |
| 80 | /// of lattice anchors around operation are necessarily identical, join |
| 81 | /// them into the same equivalent class. |
| 82 | void buildOperationEquivalentLatticeAnchor(Operation *op) override; |
| 83 | |
| 84 | const bool assumeFuncReads; |
| 85 | }; |
| 86 | } // namespace |
| 87 | |
| 88 | LogicalResult NextAccessAnalysis::visitOperation(Operation *op, |
| 89 | const NextAccess &after, |
| 90 | NextAccess *before) { |
| 91 | auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
| 92 | // If we can't reason about the memory effects, conservatively assume we can't |
| 93 | // say anything about the next access. |
| 94 | if (!memory) { |
| 95 | setToExitState(before); |
| 96 | return success(); |
| 97 | } |
| 98 | |
| 99 | SmallVector<MemoryEffects::EffectInstance> effects; |
| 100 | memory.getEffects(effects); |
| 101 | |
| 102 | // First, check if all underlying values are already known. Otherwise, avoid |
| 103 | // propagating and stay in the "undefined" state to avoid incorrectly |
| 104 | // propagating values that may be overwritten later on as that could be |
| 105 | // problematic for convergence based on monotonicity of lattice updates. |
| 106 | SmallVector<Value> underlyingValues; |
| 107 | underlyingValues.reserve(N: effects.size()); |
| 108 | for (const MemoryEffects::EffectInstance &effect : effects) { |
| 109 | Value value = effect.getValue(); |
| 110 | |
| 111 | // Effects with unspecified value are treated conservatively and we cannot |
| 112 | // assume anything about the next access. |
| 113 | if (!value) { |
| 114 | setToExitState(before); |
| 115 | return success(); |
| 116 | } |
| 117 | |
| 118 | // If cannot find the most underlying value, we cannot assume anything about |
| 119 | // the next accesses. |
| 120 | std::optional<Value> underlyingValue = |
| 121 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
| 122 | value, getUnderlyingValueFn: [&](Value value) { |
| 123 | return getOrCreateFor<UnderlyingValueLattice>( |
| 124 | dependent: getProgramPointBefore(op), anchor: value); |
| 125 | }); |
| 126 | |
| 127 | // If the underlying value is not known yet, don't propagate. |
| 128 | if (!underlyingValue) |
| 129 | return success(); |
| 130 | |
| 131 | underlyingValues.push_back(Elt: *underlyingValue); |
| 132 | } |
| 133 | |
| 134 | // Update the state if all underlying values are known. |
| 135 | ChangeResult result = before->meet(lattice: after); |
| 136 | for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) { |
| 137 | // If the underlying value is known to be unknown, set to fixpoint. |
| 138 | if (!value) { |
| 139 | setToExitState(before); |
| 140 | return success(); |
| 141 | } |
| 142 | |
| 143 | result |= before->set(value, op); |
| 144 | } |
| 145 | propagateIfChanged(state: before, changed: result); |
| 146 | return success(); |
| 147 | } |
| 148 | |
| 149 | void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) { |
| 150 | if (isMemoryEffectFree(op)) { |
| 151 | unionLatticeAnchors<NextAccess>(anchor: getProgramPointBefore(op), |
| 152 | other: getProgramPointAfter(op)); |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | void NextAccessAnalysis::visitCallControlFlowTransfer( |
| 157 | CallOpInterface call, CallControlFlowAction action, const NextAccess &after, |
| 158 | NextAccess *before) { |
| 159 | if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { |
| 160 | SmallVector<Value> underlyingValues; |
| 161 | underlyingValues.reserve(N: call->getNumOperands()); |
| 162 | for (Value operand : call.getArgOperands()) { |
| 163 | std::optional<Value> underlyingValue = |
| 164 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
| 165 | operand, [&](Value value) { |
| 166 | return getOrCreateFor<UnderlyingValueLattice>( |
| 167 | getProgramPointBefore(call.getOperation()), value); |
| 168 | }); |
| 169 | if (!underlyingValue) |
| 170 | return; |
| 171 | underlyingValues.push_back(*underlyingValue); |
| 172 | } |
| 173 | |
| 174 | ChangeResult result = before->meet(lattice: after); |
| 175 | for (Value operand : underlyingValues) { |
| 176 | result |= before->set(value: operand, op: call); |
| 177 | } |
| 178 | return propagateIfChanged(state: before, changed: result); |
| 179 | } |
| 180 | auto testCallAndStore = |
| 181 | dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
| 182 | if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
| 183 | testCallAndStore.getStoreBeforeCall()) || |
| 184 | (action == CallControlFlowAction::ExitCallee && |
| 185 | !testCallAndStore.getStoreBeforeCall()))) { |
| 186 | (void)visitOperation(op: call, after, before); |
| 187 | } else { |
| 188 | AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( |
| 189 | call: call, action, after, before); |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( |
| 194 | RegionBranchOpInterface branch, RegionBranchPoint regionFrom, |
| 195 | RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { |
| 196 | auto testStoreWithARegion = |
| 197 | dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); |
| 198 | |
| 199 | if (testStoreWithARegion && |
| 200 | ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || |
| 201 | (regionFrom.isParent() && |
| 202 | testStoreWithARegion.getStoreBeforeRegion()))) { |
| 203 | (void)visitOperation(op: branch, after: static_cast<const NextAccess &>(after), |
| 204 | before: static_cast<NextAccess *>(before)); |
| 205 | } else { |
| 206 | propagateIfChanged(state: before, changed: before->meet(lattice: after)); |
| 207 | } |
| 208 | } |
| 209 | |
| 210 | namespace { |
| 211 | struct TestNextAccessPass |
| 212 | : public PassWrapper<TestNextAccessPass, OperationPass<>> { |
| 213 | TestNextAccessPass() = default; |
| 214 | TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { |
| 215 | interprocedural = other.interprocedural; |
| 216 | assumeFuncReads = other.assumeFuncReads; |
| 217 | } |
| 218 | |
| 219 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) |
| 220 | |
| 221 | StringRef getArgument() const override { return "test-next-access" ; } |
| 222 | |
| 223 | Option<bool> interprocedural{ |
| 224 | *this, "interprocedural" , llvm::cl::init(Val: true), |
| 225 | llvm::cl::desc("perform interprocedural analysis" )}; |
| 226 | Option<bool> assumeFuncReads{ |
| 227 | *this, "assume-func-reads" , llvm::cl::init(Val: false), |
| 228 | llvm::cl::desc( |
| 229 | "assume external functions have read effect on all arguments" )}; |
| 230 | |
| 231 | static constexpr llvm::StringLiteral kTagAttrName = "name" ; |
| 232 | static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access" ; |
| 233 | static constexpr llvm::StringLiteral kAtEntryPointAttrName = |
| 234 | "next_at_entry_point" ; |
| 235 | |
| 236 | static Attribute makeNextAccessAttribute(Operation *op, |
| 237 | const DataFlowSolver &solver, |
| 238 | const NextAccess *nextAccess) { |
| 239 | if (!nextAccess) |
| 240 | return StringAttr::get(op->getContext(), "not computed" ); |
| 241 | |
| 242 | // Note that if the underlying value could not be computed or is unknown, we |
| 243 | // conservatively treat the result also unknown. |
| 244 | SmallVector<Attribute> attrs; |
| 245 | for (Value operand : op->getOperands()) { |
| 246 | std::optional<Value> underlyingValue = |
| 247 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
| 248 | value: operand, getUnderlyingValueFn: [&](Value value) { |
| 249 | return solver.lookupState<UnderlyingValueLattice>(anchor: value); |
| 250 | }); |
| 251 | if (!underlyingValue) { |
| 252 | attrs.push_back(StringAttr::get(op->getContext(), "unknown" )); |
| 253 | continue; |
| 254 | } |
| 255 | Value value = *underlyingValue; |
| 256 | const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); |
| 257 | if (!nextAcc || !nextAcc->isKnown()) { |
| 258 | attrs.push_back(StringAttr::get(op->getContext(), "unknown" )); |
| 259 | continue; |
| 260 | } |
| 261 | |
| 262 | SmallVector<Attribute> innerAttrs; |
| 263 | innerAttrs.reserve(N: nextAcc->get().size()); |
| 264 | for (Operation *nextAccOp : nextAcc->get()) { |
| 265 | if (auto nextAccTag = |
| 266 | nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) { |
| 267 | innerAttrs.push_back(Elt: nextAccTag); |
| 268 | continue; |
| 269 | } |
| 270 | std::string repr; |
| 271 | llvm::raw_string_ostream os(repr); |
| 272 | nextAccOp->print(os); |
| 273 | innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); |
| 274 | } |
| 275 | attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); |
| 276 | } |
| 277 | return ArrayAttr::get(op->getContext(), attrs); |
| 278 | } |
| 279 | |
| 280 | void runOnOperation() override { |
| 281 | Operation *op = getOperation(); |
| 282 | SymbolTableCollection symbolTable; |
| 283 | |
| 284 | auto config = DataFlowConfig().setInterprocedural(interprocedural); |
| 285 | DataFlowSolver solver(config); |
| 286 | solver.load<DeadCodeAnalysis>(); |
| 287 | solver.load<NextAccessAnalysis>(args&: symbolTable, args&: assumeFuncReads); |
| 288 | solver.load<SparseConstantPropagation>(); |
| 289 | solver.load<UnderlyingValueAnalysis>(); |
| 290 | if (failed(Result: solver.initializeAndRun(top: op))) { |
| 291 | emitError(loc: op->getLoc(), message: "dataflow solver failed" ); |
| 292 | return signalPassFailure(); |
| 293 | } |
| 294 | op->walk(callback: [&](Operation *op) { |
| 295 | auto tag = op->getAttrOfType<StringAttr>(kTagAttrName); |
| 296 | if (!tag) |
| 297 | return; |
| 298 | |
| 299 | const NextAccess *nextAccess = |
| 300 | solver.lookupState<NextAccess>(anchor: solver.getProgramPointAfter(op)); |
| 301 | op->setAttr(name: kNextAccessAttrName, |
| 302 | value: makeNextAccessAttribute(op, solver, nextAccess)); |
| 303 | |
| 304 | auto iface = dyn_cast<RegionBranchOpInterface>(op); |
| 305 | if (!iface) |
| 306 | return; |
| 307 | |
| 308 | SmallVector<Attribute> entryPointNextAccess; |
| 309 | SmallVector<RegionSuccessor> regionSuccessors; |
| 310 | iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); |
| 311 | for (const RegionSuccessor &successor : regionSuccessors) { |
| 312 | if (!successor.getSuccessor() || successor.getSuccessor()->empty()) |
| 313 | continue; |
| 314 | Block &successorBlock = successor.getSuccessor()->front(); |
| 315 | ProgramPoint *successorPoint = |
| 316 | solver.getProgramPointBefore(block: &successorBlock); |
| 317 | entryPointNextAccess.push_back(Elt: makeNextAccessAttribute( |
| 318 | op, solver, nextAccess: solver.lookupState<NextAccess>(anchor: successorPoint))); |
| 319 | } |
| 320 | op->setAttr(kAtEntryPointAttrName, |
| 321 | ArrayAttr::get(op->getContext(), entryPointNextAccess)); |
| 322 | }); |
| 323 | } |
| 324 | }; |
| 325 | } // namespace |
| 326 | |
| 327 | namespace mlir::test { |
| 328 | void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); } |
| 329 | } // namespace mlir::test |
| 330 | |