| 1 | //===- CheckUses.cpp - Expensive transform value validity checks ----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines a pass that performs expensive opt-in checks for Transform |
| 10 | // dialect values being potentially used after they have been consumed. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Transform/Transforms/Passes.h" |
| 15 | |
| 16 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 17 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 18 | #include "llvm/ADT/SetOperations.h" |
| 19 | |
| 20 | namespace mlir { |
| 21 | namespace transform { |
| 22 | #define GEN_PASS_DEF_CHECKUSESPASS |
| 23 | #include "mlir/Dialect/Transform/Transforms/Passes.h.inc" |
| 24 | } // namespace transform |
| 25 | } // namespace mlir |
| 26 | |
| 27 | using namespace mlir; |
| 28 | |
| 29 | namespace { |
| 30 | |
| 31 | /// Returns a reference to a cached set of blocks that are reachable from the |
| 32 | /// given block via edges computed by the `getNextNodes` function. For example, |
| 33 | /// if `getNextNodes` returns successors of a block, this will return the set of |
| 34 | /// reachable blocks; if it returns predecessors of a block, this will return |
| 35 | /// the set of blocks from which the given block can be reached. The block is |
| 36 | /// considered reachable form itself only if there is a cycle. |
| 37 | template <typename FnTy> |
| 38 | const llvm::SmallPtrSet<Block *, 4> & |
| 39 | getReachableImpl(Block *block, FnTy getNextNodes, |
| 40 | DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> &cache) { |
| 41 | auto [it, inserted] = cache.try_emplace(Key: block); |
| 42 | if (!inserted) |
| 43 | return it->getSecond(); |
| 44 | |
| 45 | llvm::SmallPtrSet<Block *, 4> &reachable = it->second; |
| 46 | SmallVector<Block *> worklist; |
| 47 | worklist.push_back(Elt: block); |
| 48 | while (!worklist.empty()) { |
| 49 | Block *current = worklist.pop_back_val(); |
| 50 | for (Block *predecessor : getNextNodes(current)) { |
| 51 | // The block is reachable from its transitive predecessors. Only add |
| 52 | // them to the worklist if they weren't already visited. |
| 53 | if (reachable.insert(Ptr: predecessor).second) |
| 54 | worklist.push_back(Elt: predecessor); |
| 55 | } |
| 56 | } |
| 57 | return reachable; |
| 58 | } |
| 59 | |
| 60 | /// An analysis that identifies whether a value allocated by a Transform op may |
| 61 | /// be used by another such op after it may have been freed by a third op on |
| 62 | /// some control flow path. This is conceptually similar to a data flow |
| 63 | /// analysis, but relies on side effects related to particular values that |
| 64 | /// currently cannot be modeled by the MLIR data flow analysis framework (also, |
| 65 | /// the lattice element would be rather expensive as it would need to include |
| 66 | /// live and/or freed values for each operation). |
| 67 | /// |
| 68 | /// This analysis is conservatively pessimisic: it will consider that a value |
| 69 | /// may be freed if it is freed on any possible control flow path between its |
| 70 | /// allocation and a relevant use, even if the control never actually flows |
| 71 | /// through the operation that frees the value. It also does not differentiate |
| 72 | /// between may- (freed on at least one control flow path) and must-free (freed |
| 73 | /// on all possible control flow paths) because it would require expensive graph |
| 74 | /// algorithms. |
| 75 | /// |
| 76 | /// It is intended as an additional non-blocking verification or debugging aid |
| 77 | /// for ops in the Transform dialect. It leverages the requirement for Transform |
| 78 | /// dialect ops to implement the MemoryEffectsOpInterface, and expects the |
| 79 | /// values in the Transform IR to have an allocation effect on the |
| 80 | /// TransformMappingResource when defined. |
| 81 | class TransformOpMemFreeAnalysis { |
| 82 | public: |
| 83 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformOpMemFreeAnalysis) |
| 84 | |
| 85 | /// Computes the analysis for Transform ops nested in the given operation. |
| 86 | explicit TransformOpMemFreeAnalysis(Operation *root) { |
| 87 | root->walk(callback: [&](Operation *op) { |
| 88 | if (isa<transform::TransformOpInterface>(Val: op)) { |
| 89 | collectFreedValues(root: op); |
| 90 | return WalkResult::skip(); |
| 91 | } |
| 92 | return WalkResult::advance(); |
| 93 | }); |
| 94 | } |
| 95 | |
| 96 | /// A list of operations that may be deleting a value. Non-empty list |
| 97 | /// contextually converts to boolean "true" value. |
| 98 | class PotentialDeleters { |
| 99 | public: |
| 100 | /// Creates an empty list that corresponds to the value being live. |
| 101 | static PotentialDeleters live() { return PotentialDeleters({}); } |
| 102 | |
| 103 | /// Creates a list from the operations that may be deleting the value. |
| 104 | static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) { |
| 105 | return PotentialDeleters(deleters); |
| 106 | } |
| 107 | |
| 108 | /// Converts to "true" if there are operations that may be deleting the |
| 109 | /// value. |
| 110 | explicit operator bool() const { return !deleters.empty(); } |
| 111 | |
| 112 | /// Concatenates the lists of operations that may be deleting the value. The |
| 113 | /// value is known to be live if the reuslting list is still empty. |
| 114 | PotentialDeleters &operator|=(const PotentialDeleters &other) { |
| 115 | llvm::append_range(C&: deleters, R: other.deleters); |
| 116 | return *this; |
| 117 | } |
| 118 | |
| 119 | /// Returns the list of ops that may be deleting the value. |
| 120 | ArrayRef<Operation *> getOps() const { return deleters; } |
| 121 | |
| 122 | private: |
| 123 | /// Constructs the list from the given operations. |
| 124 | explicit PotentialDeleters(ArrayRef<Operation *> ops) { |
| 125 | llvm::append_range(C&: deleters, R&: ops); |
| 126 | } |
| 127 | |
| 128 | /// The list of operations that may be deleting the value. |
| 129 | SmallVector<Operation *> deleters; |
| 130 | }; |
| 131 | |
| 132 | /// Returns the list of operations that may be deleting the operand value on |
| 133 | /// any control flow path between the definition of the value and its use as |
| 134 | /// the given operand. For the purposes of this analysis, the value is |
| 135 | /// considered to be allocated at its definition point and never re-allocated. |
| 136 | PotentialDeleters isUseLive(OpOperand &operand) { |
| 137 | const llvm::SmallPtrSet<Operation *, 2> &deleters = freedBy[operand.get()]; |
| 138 | if (deleters.empty()) |
| 139 | return live(); |
| 140 | |
| 141 | #ifndef NDEBUG |
| 142 | // Check that the definition point actually allocates the value. If the |
| 143 | // definition is a block argument, it may be just forwarding the operand of |
| 144 | // the parent op without doing a new allocation, allow that. We currently |
| 145 | // don't have the capability to analyze region-based control flow here. |
| 146 | // |
| 147 | // TODO: when this ported to the dataflow analysis infra, we should have |
| 148 | // proper support for region-based control flow. |
| 149 | Operation *valueSource = |
| 150 | isa<OpResult>(operand.get()) |
| 151 | ? operand.get().getDefiningOp() |
| 152 | : operand.get().getParentBlock()->getParentOp(); |
| 153 | auto iface = cast<MemoryEffectOpInterface>(valueSource); |
| 154 | SmallVector<MemoryEffects::EffectInstance> instances; |
| 155 | iface.getEffectsOnResource(transform::TransformMappingResource::get(), |
| 156 | instances); |
| 157 | assert((isa<BlockArgument>(operand.get()) || |
| 158 | hasEffect<MemoryEffects::Allocate>(instances, operand.get())) && |
| 159 | "expected the op defining the value to have an allocation effect " |
| 160 | "on it" ); |
| 161 | #endif |
| 162 | |
| 163 | // Collect ancestors of the use operation. |
| 164 | Block *defBlock = operand.get().getParentBlock(); |
| 165 | SmallVector<Operation *> ancestors; |
| 166 | Operation *ancestor = operand.getOwner(); |
| 167 | do { |
| 168 | ancestors.push_back(Elt: ancestor); |
| 169 | if (ancestor->getParentRegion() == defBlock->getParent()) |
| 170 | break; |
| 171 | ancestor = ancestor->getParentOp(); |
| 172 | } while (true); |
| 173 | std::reverse(first: ancestors.begin(), last: ancestors.end()); |
| 174 | |
| 175 | // Consider the control flow from the definition point of the value to its |
| 176 | // use point. If the use is located in some nested region, consider the path |
| 177 | // from the entry block of the region to the use. |
| 178 | for (Operation *ancestor : ancestors) { |
| 179 | // The block should be considered partially if it is the block that |
| 180 | // contains the definition (allocation) of the value being used, and the |
| 181 | // value is defined in the middle of the block, i.e., is not a block |
| 182 | // argument. |
| 183 | bool isOutermost = ancestor == ancestors.front(); |
| 184 | bool isFromBlockPartial = isOutermost && isa<OpResult>(Val: operand.get()); |
| 185 | |
| 186 | // Check if the value may be freed by operations between its definition |
| 187 | // (allocation) point in its block and the terminator of the block or the |
| 188 | // ancestor of the use if it is located in the same block. This is only |
| 189 | // done for partial blocks here, full blocks will be considered below |
| 190 | // similarly to other blocks. |
| 191 | if (isFromBlockPartial) { |
| 192 | bool defUseSameBlock = ancestor->getBlock() == defBlock; |
| 193 | // Consider all ops from the def to its block terminator, except the |
| 194 | // when the use is in the same block, in which case only consider the |
| 195 | // ops until the user. |
| 196 | if (PotentialDeleters potentialDeleters = isFreedInBlockAfter( |
| 197 | root: operand.get().getDefiningOp(), value: operand.get(), |
| 198 | before: defUseSameBlock ? ancestor : nullptr)) |
| 199 | return potentialDeleters; |
| 200 | } |
| 201 | |
| 202 | // Check if the value may be freed by opeations preceding the ancestor in |
| 203 | // its block. Skip the check for partial blocks that contain both the |
| 204 | // definition and the use point, as this has been already checked above. |
| 205 | if (!isFromBlockPartial || ancestor->getBlock() != defBlock) { |
| 206 | if (PotentialDeleters potentialDeleters = |
| 207 | isFreedInBlockBefore(root: ancestor, value: operand.get())) |
| 208 | return potentialDeleters; |
| 209 | } |
| 210 | |
| 211 | // Check if the value may be freed by operations in any of the blocks |
| 212 | // between the definition point (in the outermost region) or the entry |
| 213 | // block of the region (in other regions) and the operand or its ancestor |
| 214 | // in the region. This includes the entire "form" block if (1) the block |
| 215 | // has not been considered as partial above and (2) the block can be |
| 216 | // reached again through some control-flow loop. This includes the entire |
| 217 | // "to" block if it can be reached form itself through some control-flow |
| 218 | // cycle, regardless of whether it has been visited before. |
| 219 | Block *ancestorBlock = ancestor->getBlock(); |
| 220 | Block *from = |
| 221 | isOutermost ? defBlock : &ancestorBlock->getParent()->front(); |
| 222 | if (PotentialDeleters potentialDeleters = |
| 223 | isMaybeFreedOnPaths(from, to: ancestorBlock, value: operand.get(), |
| 224 | /*alwaysIncludeFrom=*/!isFromBlockPartial)) |
| 225 | return potentialDeleters; |
| 226 | } |
| 227 | return live(); |
| 228 | } |
| 229 | |
| 230 | private: |
| 231 | /// Make PotentialDeleters constructors available with shorter names. |
| 232 | static PotentialDeleters maybeFreed(ArrayRef<Operation *> deleters) { |
| 233 | return PotentialDeleters::maybeFreed(deleters); |
| 234 | } |
| 235 | static PotentialDeleters live() { return PotentialDeleters::live(); } |
| 236 | |
| 237 | /// Returns the list of operations that may be deleting the given value betwen |
| 238 | /// the first and last operations, non-inclusive. `getNext` indicates the |
| 239 | /// direction of the traversal. |
| 240 | PotentialDeleters |
| 241 | isFreedBetween(Value value, Operation *first, Operation *last, |
| 242 | llvm::function_ref<Operation *(Operation *)> getNext) const { |
| 243 | auto it = freedBy.find(Val: value); |
| 244 | if (it == freedBy.end()) |
| 245 | return live(); |
| 246 | const llvm::SmallPtrSet<Operation *, 2> &deleters = it->getSecond(); |
| 247 | for (Operation *op = getNext(first); op != last; op = getNext(op)) { |
| 248 | if (deleters.contains(Ptr: op)) |
| 249 | return maybeFreed(deleters: op); |
| 250 | } |
| 251 | return live(); |
| 252 | } |
| 253 | |
| 254 | /// Returns the list of operations that may be deleting the given value |
| 255 | /// between `root` and `before` values. `root` is expected to be in the same |
| 256 | /// block as `before` and precede it. If `before` is null, consider all |
| 257 | /// operations until the end of the block including the terminator. |
| 258 | PotentialDeleters isFreedInBlockAfter(Operation *root, Value value, |
| 259 | Operation *before = nullptr) const { |
| 260 | return isFreedBetween(value, first: root, last: before, |
| 261 | getNext: [](Operation *op) { return op->getNextNode(); }); |
| 262 | } |
| 263 | |
| 264 | /// Returns the list of operations that may be deleting the given value |
| 265 | /// between the entry of the block and the `root` operation. |
| 266 | PotentialDeleters isFreedInBlockBefore(Operation *root, Value value) const { |
| 267 | return isFreedBetween(value, first: root, last: nullptr, |
| 268 | getNext: [](Operation *op) { return op->getPrevNode(); }); |
| 269 | } |
| 270 | |
| 271 | /// Returns the list of operations that may be deleting the given value on |
| 272 | /// any of the control flow paths between the "form" and the "to" block. The |
| 273 | /// operations from any block visited on any control flow path are |
| 274 | /// consdiered. The "from" block is considered if there is a control flow |
| 275 | /// cycle going through it, i.e., if there is a possibility that all |
| 276 | /// operations in this block are visited or if the `alwaysIncludeFrom` flag is |
| 277 | /// set. The "to" block is considered only if there is a control flow cycle |
| 278 | /// going through it. |
| 279 | PotentialDeleters isMaybeFreedOnPaths(Block *from, Block *to, Value value, |
| 280 | bool alwaysIncludeFrom) { |
| 281 | // Find all blocks that lie on any path between "from" and "to", i.e., the |
| 282 | // intersection of blocks reachable from "from" and blocks from which "to" |
| 283 | // is rechable. |
| 284 | const llvm::SmallPtrSet<Block *, 4> &sources = getReachableFrom(block: to); |
| 285 | if (!sources.contains(Ptr: from)) |
| 286 | return live(); |
| 287 | |
| 288 | llvm::SmallPtrSet<Block *, 4> reachable(getReachable(block: from)); |
| 289 | llvm::set_intersect(S1&: reachable, S2: sources); |
| 290 | |
| 291 | // If requested, include the "from" block that may not be present in the set |
| 292 | // of visited blocks when there is no cycle going through it. |
| 293 | if (alwaysIncludeFrom) |
| 294 | reachable.insert(Ptr: from); |
| 295 | |
| 296 | // Join potential deleters from all blocks as we don't know here which of |
| 297 | // the paths through the control flow is taken. |
| 298 | PotentialDeleters potentialDeleters = live(); |
| 299 | for (Block *block : reachable) { |
| 300 | for (Operation &op : *block) { |
| 301 | if (freedBy[value].count(Ptr: &op)) |
| 302 | potentialDeleters |= maybeFreed(deleters: &op); |
| 303 | } |
| 304 | } |
| 305 | return potentialDeleters; |
| 306 | } |
| 307 | |
| 308 | /// Popualtes `reachable` with the set of blocks that are rechable from the |
| 309 | /// given block. A block is considered reachable from itself if there is a |
| 310 | /// cycle in the control-flow graph that invovles the block. |
| 311 | const llvm::SmallPtrSet<Block *, 4> &getReachable(Block *block) { |
| 312 | return getReachableImpl( |
| 313 | block, getNextNodes: [](Block *b) { return b->getSuccessors(); }, cache&: reachableCache); |
| 314 | } |
| 315 | |
| 316 | /// Populates `sources` with the set of blocks from which the given block is |
| 317 | /// reachable. |
| 318 | const llvm::SmallPtrSet<Block *, 4> &getReachableFrom(Block *block) { |
| 319 | return getReachableImpl( |
| 320 | block, getNextNodes: [](Block *b) { return b->getPredecessors(); }, |
| 321 | cache&: reachableFromCache); |
| 322 | } |
| 323 | |
| 324 | /// Returns true of `instances` contains an effect of `EffectTy` on `value`. |
| 325 | template <typename EffectTy> |
| 326 | static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> instances, |
| 327 | Value value) { |
| 328 | return llvm::any_of(instances, |
| 329 | [&](const MemoryEffects::EffectInstance &instance) { |
| 330 | return instance.getValue() == value && |
| 331 | isa<EffectTy>(instance.getEffect()); |
| 332 | }); |
| 333 | } |
| 334 | |
| 335 | /// Records the values that are being freed by an operation or any of its |
| 336 | /// children in `freedBy`. |
| 337 | void collectFreedValues(Operation *root) { |
| 338 | SmallVector<MemoryEffects::EffectInstance> instances; |
| 339 | root->walk(callback: [&](Operation *child) { |
| 340 | if (isa<transform::PatternDescriptorOpInterface>(Val: child)) |
| 341 | return; |
| 342 | // TODO: extend this to conservatively handle operations with undeclared |
| 343 | // side effects as maybe freeing the operands. |
| 344 | auto iface = cast<MemoryEffectOpInterface>(Val: child); |
| 345 | instances.clear(); |
| 346 | iface.getEffectsOnResource(resource: transform::TransformMappingResource::get(), |
| 347 | effects&: instances); |
| 348 | for (Value operand : child->getOperands()) { |
| 349 | if (hasEffect<MemoryEffects::Free>(instances, value: operand)) { |
| 350 | // All parents of the operation that frees a value should be |
| 351 | // considered as potentially freeing the value as well. |
| 352 | // |
| 353 | // TODO: differentiate between must-free/may-free as well as between |
| 354 | // this op having the effect and children having the effect. This may |
| 355 | // require some analysis of all control flow paths through the nested |
| 356 | // regions as well as a mechanism to separate proper side effects from |
| 357 | // those obtained by nesting. |
| 358 | Operation *parent = child; |
| 359 | do { |
| 360 | freedBy[operand].insert(Ptr: parent); |
| 361 | if (parent == root) |
| 362 | break; |
| 363 | parent = parent->getParentOp(); |
| 364 | } while (true); |
| 365 | } |
| 366 | } |
| 367 | }); |
| 368 | } |
| 369 | |
| 370 | /// The mapping from a value to operations that have a Free memory effect on |
| 371 | /// the TransformMappingResource and associated with this value, or to |
| 372 | /// Transform operations transitively containing such operations. |
| 373 | DenseMap<Value, llvm::SmallPtrSet<Operation *, 2>> freedBy; |
| 374 | |
| 375 | /// Caches for sets of reachable blocks. |
| 376 | DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableCache; |
| 377 | DenseMap<Block *, llvm::SmallPtrSet<Block *, 4>> reachableFromCache; |
| 378 | }; |
| 379 | |
| 380 | //// A simple pass that warns about any use of a value by a transform operation |
| 381 | // that may be using the value after it has been freed. |
| 382 | class CheckUsesPass : public transform::impl::CheckUsesPassBase<CheckUsesPass> { |
| 383 | public: |
| 384 | void runOnOperation() override { |
| 385 | auto &analysis = getAnalysis<TransformOpMemFreeAnalysis>(); |
| 386 | |
| 387 | getOperation()->walk(callback: [&](Operation *child) { |
| 388 | for (OpOperand &operand : child->getOpOperands()) { |
| 389 | TransformOpMemFreeAnalysis::PotentialDeleters deleters = |
| 390 | analysis.isUseLive(operand); |
| 391 | if (!deleters) |
| 392 | continue; |
| 393 | |
| 394 | InFlightDiagnostic diag = child->emitWarning() |
| 395 | << "operand #" << operand.getOperandNumber() |
| 396 | << " may be used after free" ; |
| 397 | diag.attachNote(noteLoc: operand.get().getLoc()) << "allocated here" ; |
| 398 | for (Operation *d : deleters.getOps()) { |
| 399 | diag.attachNote(noteLoc: d->getLoc()) << "freed here" ; |
| 400 | } |
| 401 | } |
| 402 | }); |
| 403 | } |
| 404 | }; |
| 405 | |
| 406 | } // namespace |
| 407 | |
| 408 | |