//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Analysis functions specific to slicing in Function.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"

///
/// Implements Analysis functions specific to slicing in Function.
///

using namespace mlir;

static void
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
                    const SliceOptions::TransitiveFilter &filter = nullptr) {
  if (!op)
    return;

  // Evaluate whether we should keep this use.
  // This is useful in particular to implement scoping; i.e. return the
  // transitive forwardSlice in the current scope.
  if (filter && !filter(op))
    return;

  for (Region &region : op->getRegions())
    for (Block &block : region)
      for (Operation &blockOp : block)
        if (forwardSlice->count(&blockOp) == 0)
          getForwardSliceImpl(&blockOp, forwardSlice, filter);
  for (Value result : op->getResults()) {
    for (Operation *userOp : result.getUsers())
      if (forwardSlice->count(userOp) == 0)
        getForwardSliceImpl(userOp, forwardSlice, filter);
  }

  forwardSlice->insert(op);
}

void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
                           const ForwardSliceOptions &options) {
  getForwardSliceImpl(op, forwardSlice, options.filter);
  if (!options.inclusive) {
    // Don't insert the top level operation, we just queried on it and don't
    // want it in the results.
    forwardSlice->remove(op);
  }

  // Reverse to get back the actual topological order.
  // std::reverse does not work out of the box on SetVector and I want an
  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
  forwardSlice->insert(v.rbegin(), v.rend());
}
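// Example use of the `getForwardSlice` overload above (an illustrative sketch,
// not part of the upstream API surface; `rootOp` and `scopeOp` are assumed to
// be operations chosen by the caller):
//
//   SetVector<Operation *> slice;
//   ForwardSliceOptions options;
//   // Restrict the slice to ops nested under `scopeOp`.
//   options.filter = [&](Operation *op) { return scopeOp->isAncestor(op); };
//   getForwardSlice(rootOp, &slice, options);
//   // `slice` now holds the transitive users of `rootOp`'s results, in
//   // topological order.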

void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
                           const SliceOptions &options) {
  for (Operation *user : root.getUsers())
    getForwardSliceImpl(user, forwardSlice, options.filter);

  // Reverse to get back the actual topological order.
  // std::reverse does not work out of the box on SetVector and I want an
  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
  SmallVector<Operation *, 0> v(forwardSlice->takeVector());
  forwardSlice->insert(v.rbegin(), v.rend());
}

static LogicalResult getBackwardSliceImpl(Operation *op,
                                          SetVector<Operation *> *backwardSlice,
                                          const BackwardSliceOptions &options) {
  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return success();

  // Evaluate whether we should keep this def.
  // This is useful in particular to implement scoping; i.e. return the
  // transitive backwardSlice in the current scope.
  if (options.filter && !options.filter(op))
    return success();

  auto processValue = [&](Value value) {
    if (auto *definingOp = value.getDefiningOp()) {
      if (backwardSlice->count(definingOp) == 0)
        return getBackwardSliceImpl(definingOp, backwardSlice, options);
    } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
      if (options.omitBlockArguments)
        return success();

      Block *block = blockArg.getOwner();
      Operation *parentOp = block->getParentOp();
      // TODO: determine whether we want to recurse backward into the other
      // blocks of parentOp, which are not technically backward unless they
      // flow into us. For now, just bail.
      if (parentOp && backwardSlice->count(parentOp) == 0) {
        if (parentOp->getNumRegions() == 1 &&
            llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
          return getBackwardSliceImpl(parentOp, backwardSlice, options);
        }
      }
    } else {
      return failure();
    }
    return success();
  };

  bool succeeded = true;
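  // Unless `omitUsesFromAbove` is set, also traverse values that are used
  // inside this op's regions but defined above them ("uses from above").
  // Illustrative IR sketch (op names are placeholders):
  //
  //   %0 = "test.def"() : () -> index
  //   "test.region_op"() ({
  //     "test.use"(%0) : (index) -> ()   // %0 is used from above.
  //   }) : () -> ()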
  if (!options.omitUsesFromAbove) {
    llvm::for_each(op->getRegions(), [&](Region &region) {
      // Walk this region recursively to collect the regions that descend from
      // this op's nested regions (inclusive).
      SmallPtrSet<Region *, 4> descendents;
      region.walk(
          [&](Region *childRegion) { descendents.insert(childRegion); });
      region.walk([&](Operation *op) {
        for (OpOperand &operand : op->getOpOperands()) {
          if (!descendents.contains(operand.get().getParentRegion()))
            if (!processValue(operand.get()).succeeded()) {
              return WalkResult::interrupt();
            }
        }
        return WalkResult::advance();
      });
    });
  }
  llvm::for_each(op->getOperands(), processValue);

  backwardSlice->insert(op);
  return success(succeeded);
}

LogicalResult mlir::getBackwardSlice(Operation *op,
                                     SetVector<Operation *> *backwardSlice,
                                     const BackwardSliceOptions &options) {
  LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);

  if (!options.inclusive) {
    // Don't insert the top level operation, we just queried on it and don't
    // want it in the results.
    backwardSlice->remove(op);
  }
  return result;
}

LogicalResult mlir::getBackwardSlice(Value root,
                                     SetVector<Operation *> *backwardSlice,
                                     const BackwardSliceOptions &options) {
  if (Operation *definingOp = root.getDefiningOp()) {
    return getBackwardSlice(definingOp, backwardSlice, options);
  }
  Operation *bbArgOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
  return getBackwardSlice(bbArgOwner, backwardSlice, options);
}
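// Example use of the `getBackwardSlice` overloads above (an illustrative
// sketch; `val` is assumed to be a Value chosen by the caller):
//
//   SetVector<Operation *> slice;
//   BackwardSliceOptions options;
//   options.omitBlockArguments = true; // Do not traverse through block args.
//   if (failed(getBackwardSlice(val, &slice, options)))
//     return failure();
//   // `slice` now holds the ops that transitively define `val`.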

SetVector<Operation *>
mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
               const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);

  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    LogicalResult result =
        getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    assert(result.succeeded());
    (void)result;
    slice.insert_range(backwardSlice);

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    slice.insert_range(forwardSlice);
    ++currentIndex;
  }
  return topologicalSort(slice);
}
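// Example use of `getSlice` above (an illustrative sketch; `anchorOp` is
// assumed to be an operation chosen by the caller):
//
//   BackwardSliceOptions bwdOptions;
//   ForwardSliceOptions fwdOptions;
//   SetVector<Operation *> slice = getSlice(anchorOp, bwdOptions, fwdOptions);
//   // `slice` is the fixed point of repeatedly taking backward and forward
//   // slices of the ops discovered so far, returned topologically sorted.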

/// Returns true if `value` (transitively) depends on iteration-carried values
/// of the given `ancestorOp`.
static bool dependsOnCarriedVals(Value value,
                                 ArrayRef<BlockArgument> iterCarriedArgs,
                                 Operation *ancestorOp) {
  // Compute the backward slice of the value.
  SetVector<Operation *> slice;
  BackwardSliceOptions sliceOptions;
  sliceOptions.filter = [&](Operation *op) {
    return !ancestorOp->isAncestor(op);
  };
  LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
  assert(result.succeeded());
  (void)result;

  // Check that none of the operands of the operations in the backward slice
  // are loop iteration arguments, and neither is the value itself.
  SmallPtrSet<Value, 8> iterCarriedValSet(llvm::from_range, iterCarriedArgs);
  if (iterCarriedValSet.contains(value))
    return true;

  for (Operation *op : slice)
    for (Value operand : op->getOperands())
      if (iterCarriedValSet.contains(operand))
        return true;

  return false;
}

/// Utility to match a generic reduction given a list of iteration-carried
/// arguments, `iterCarriedArgs` and the position of the potential reduction
/// argument within the list, `redPos`. If a reduction is matched, returns the
/// reduced value and the topologically-sorted list of combiner operations
/// involved in the reduction. Otherwise, returns a null value.
///
/// The matching algorithm relies on the following invariants, which are
/// subject to change:
/// 1. The first combiner operation must be a binary operation with the
///    iteration-carried value and the reduced value as operands.
/// 2. The iteration-carried value and combiner operations must be side
///    effect-free, have a single result and a single use.
/// 3. Combiner operations must be immediately nested in the region op
///    performing the reduction.
/// 4. Reduction def-use chain must end in a terminator op that yields the
///    next iteration/output values in the same order as the iteration-carried
///    values in `iterCarriedArgs`.
/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
///    of the region op performing the reduction.
///
/// This utility is generic enough to detect reductions involving multiple
/// combiner operations (disabled for now) across multiple dialects, including
/// Linalg, Affine and SCF. For the sake of genericity, it does not return
/// specific enum values for the combiner operations since its goal is also
/// matching reductions without pre-defined semantics in core MLIR. It's up to
/// each client to make sense out of the list of combiner operations. It's also
/// up to each client to check for additional invariants on the expected
/// reductions not covered by this generic matching.
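///
/// For illustration, a minimal `scf.for` reduction of the shape this utility
/// is intended to match (a sketch; SSA names and types are arbitrary):
///
///   %res = scf.for %i = %lb to %ub step %c1 iter_args(%acc = %init) -> (f32) {
///     %v = memref.load %buf[%i] : memref<?xf32>
///     %sum = arith.addf %acc, %v : f32
///     scf.yield %sum : f32
///   }
///
/// Here `%acc` is the iteration-carried argument, `arith.addf` is the single
/// combiner operation, and `%v` is the reduced value that would be returned.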
Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
                           unsigned redPos,
                           SmallVectorImpl<Operation *> &combinerOps) {
  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");

  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
  if (!redCarriedVal.hasOneUse())
    return nullptr;

  // For now, the first combiner op must be a binary op.
  Operation *combinerOp = *redCarriedVal.getUsers().begin();
  if (combinerOp->getNumOperands() != 2)
    return nullptr;
  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
                         ? combinerOp->getOperand(1)
                         : combinerOp->getOperand(0);

  Operation *redRegionOp =
      iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
    return nullptr;

  // Traverse the def-use chain starting from the first combiner op until a
  // terminator is found. Gather all the combiner ops along the way in
  // topological order.
  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
    if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
        !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
      return nullptr;

    combinerOps.push_back(combinerOp);
    combinerOp = *combinerOp->getUsers().begin();
  }

  // Limit matching to single combiner op until we can properly test reductions
  // involving multiple combiners.
  if (combinerOps.size() != 1)
    return nullptr;

  // Check that the yielded value is in the same position as in
  // `iterCarriedArgs`.
  Operation *terminatorOp = combinerOp;
  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
    return nullptr;

  return reducedVal;
}