| 1 | //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements loop fusion on parallel loops. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
| 14 | |
| 15 | #include "mlir/Analysis/AliasAnalysis.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 18 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 19 | #include "mlir/IR/Builders.h" |
| 20 | #include "mlir/IR/IRMapping.h" |
| 21 | #include "mlir/IR/OpDefinition.h" |
| 22 | #include "mlir/IR/OperationSupport.h" |
| 23 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 24 | |
| 25 | namespace mlir { |
| 26 | #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION |
| 27 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
| 28 | } // namespace mlir |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::scf; |
| 32 | |
| 33 | /// Verify there are no nested ParallelOps. |
| 34 | static bool hasNestedParallelOp(ParallelOp ploop) { |
| 35 | auto walkResult = |
| 36 | ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); |
| 37 | return walkResult.wasInterrupted(); |
| 38 | } |
| 39 | |
| 40 | /// Verify equal iteration spaces. |
| 41 | static bool equalIterationSpaces(ParallelOp firstPloop, |
| 42 | ParallelOp secondPloop) { |
| 43 | if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) |
| 44 | return false; |
| 45 | |
| 46 | auto matchOperands = [&](const OperandRange &lhs, |
| 47 | const OperandRange &rhs) -> bool { |
| 48 | // TODO: Extend this to support aliases and equal constants. |
| 49 | return std::equal(first1: lhs.begin(), last1: lhs.end(), first2: rhs.begin()); |
| 50 | }; |
| 51 | return matchOperands(firstPloop.getLowerBound(), |
| 52 | secondPloop.getLowerBound()) && |
| 53 | matchOperands(firstPloop.getUpperBound(), |
| 54 | secondPloop.getUpperBound()) && |
| 55 | matchOperands(firstPloop.getStep(), secondPloop.getStep()); |
| 56 | } |
| 57 | |
| 58 | /// Checks if the parallel loops have mixed access to the same buffers. Returns |
| 59 | /// `true` if the first parallel loop writes to the same indices that the second |
| 60 | /// loop reads. |
| 61 | static bool haveNoReadsAfterWriteExceptSameIndex( |
| 62 | ParallelOp firstPloop, ParallelOp secondPloop, |
| 63 | const IRMapping &firstToSecondPloopIndices, |
| 64 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
| 65 | DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; |
| 66 | SmallVector<Value> bufferStoresVec; |
| 67 | firstPloop.getBody()->walk([&](memref::StoreOp store) { |
| 68 | bufferStores[store.getMemRef()].push_back(store.getIndices()); |
| 69 | bufferStoresVec.emplace_back(store.getMemRef()); |
| 70 | }); |
| 71 | auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { |
| 72 | Value loadMem = load.getMemRef(); |
| 73 | // Stop if the memref is defined in secondPloop body. Careful alias analysis |
| 74 | // is needed. |
| 75 | auto *memrefDef = loadMem.getDefiningOp(); |
| 76 | if (memrefDef && memrefDef->getBlock() == load->getBlock()) |
| 77 | return WalkResult::interrupt(); |
| 78 | |
| 79 | for (Value store : bufferStoresVec) |
| 80 | if (store != loadMem && mayAlias(store, loadMem)) |
| 81 | return WalkResult::interrupt(); |
| 82 | |
| 83 | auto write = bufferStores.find(Val: loadMem); |
| 84 | if (write == bufferStores.end()) |
| 85 | return WalkResult::advance(); |
| 86 | |
| 87 | // Check that at last one store was retrieved |
| 88 | if (write->second.empty()) |
| 89 | return WalkResult::interrupt(); |
| 90 | |
| 91 | auto storeIndices = write->second.front(); |
| 92 | |
| 93 | // Multiple writes to the same memref are allowed only on the same indices |
| 94 | for (const auto &othStoreIndices : write->second) { |
| 95 | if (othStoreIndices != storeIndices) |
| 96 | return WalkResult::interrupt(); |
| 97 | } |
| 98 | |
| 99 | // Check that the load indices of secondPloop coincide with store indices of |
| 100 | // firstPloop for the same memrefs. |
| 101 | auto loadIndices = load.getIndices(); |
| 102 | if (storeIndices.size() != loadIndices.size()) |
| 103 | return WalkResult::interrupt(); |
| 104 | for (int i = 0, e = storeIndices.size(); i < e; ++i) { |
| 105 | if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != |
| 106 | loadIndices[i]) { |
| 107 | auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); |
| 108 | auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); |
| 109 | if (storeIndexDefOp && loadIndexDefOp) { |
| 110 | if (!isMemoryEffectFree(storeIndexDefOp)) |
| 111 | return WalkResult::interrupt(); |
| 112 | if (!isMemoryEffectFree(loadIndexDefOp)) |
| 113 | return WalkResult::interrupt(); |
| 114 | if (!OperationEquivalence::isEquivalentTo( |
| 115 | storeIndexDefOp, loadIndexDefOp, |
| 116 | [&](Value storeIndex, Value loadIndex) { |
| 117 | if (firstToSecondPloopIndices.lookupOrDefault(from: storeIndex) != |
| 118 | firstToSecondPloopIndices.lookupOrDefault(from: loadIndex)) |
| 119 | return failure(); |
| 120 | else |
| 121 | return success(); |
| 122 | }, |
| 123 | /*markEquivalent=*/nullptr, |
| 124 | OperationEquivalence::Flags::IgnoreLocations)) { |
| 125 | return WalkResult::interrupt(); |
| 126 | } |
| 127 | } else { |
| 128 | return WalkResult::interrupt(); |
| 129 | } |
| 130 | } |
| 131 | } |
| 132 | return WalkResult::advance(); |
| 133 | }); |
| 134 | return !walkResult.wasInterrupted(); |
| 135 | } |
| 136 | |
| 137 | /// Analyzes dependencies in the most primitive way by checking simple read and |
| 138 | /// write patterns. |
| 139 | static LogicalResult |
| 140 | verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, |
| 141 | const IRMapping &firstToSecondPloopIndices, |
| 142 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
| 143 | if (!haveNoReadsAfterWriteExceptSameIndex( |
| 144 | firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) |
| 145 | return failure(); |
| 146 | |
| 147 | IRMapping secondToFirstPloopIndices; |
| 148 | secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), |
| 149 | firstPloop.getBody()->getArguments()); |
| 150 | return success(haveNoReadsAfterWriteExceptSameIndex( |
| 151 | secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); |
| 152 | } |
| 153 | |
| 154 | static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, |
| 155 | const IRMapping &firstToSecondPloopIndices, |
| 156 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
| 157 | return !hasNestedParallelOp(firstPloop) && |
| 158 | !hasNestedParallelOp(secondPloop) && |
| 159 | equalIterationSpaces(firstPloop, secondPloop) && |
| 160 | succeeded(verifyDependencies(firstPloop, secondPloop, |
| 161 | firstToSecondPloopIndices, mayAlias)); |
| 162 | } |
| 163 | |
| 164 | /// Prepends operations of firstPloop's body into secondPloop's body. |
| 165 | /// Updates secondPloop with new loop. |
| 166 | static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, |
| 167 | OpBuilder builder, |
| 168 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
| 169 | Block *block1 = firstPloop.getBody(); |
| 170 | Block *block2 = secondPloop.getBody(); |
| 171 | IRMapping firstToSecondPloopIndices; |
| 172 | firstToSecondPloopIndices.map(from: block1->getArguments(), to: block2->getArguments()); |
| 173 | |
| 174 | if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, |
| 175 | mayAlias)) |
| 176 | return; |
| 177 | |
| 178 | DominanceInfo dom; |
| 179 | // We are fusing first loop into second, make sure there are no users of the |
| 180 | // first loop results between loops. |
| 181 | for (Operation *user : firstPloop->getUsers()) |
| 182 | if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) |
| 183 | return; |
| 184 | |
| 185 | ValueRange inits1 = firstPloop.getInitVals(); |
| 186 | ValueRange inits2 = secondPloop.getInitVals(); |
| 187 | |
| 188 | SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); |
| 189 | newInitVars.append(in_start: inits2.begin(), in_end: inits2.end()); |
| 190 | |
| 191 | IRRewriter b(builder); |
| 192 | b.setInsertionPoint(secondPloop); |
| 193 | auto newSecondPloop = b.create<ParallelOp>( |
| 194 | secondPloop.getLoc(), secondPloop.getLowerBound(), |
| 195 | secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); |
| 196 | |
| 197 | Block *newBlock = newSecondPloop.getBody(); |
| 198 | auto term1 = cast<ReduceOp>(block1->getTerminator()); |
| 199 | auto term2 = cast<ReduceOp>(block2->getTerminator()); |
| 200 | |
| 201 | b.inlineBlockBefore(source: block2, dest: newBlock, before: newBlock->begin(), |
| 202 | argValues: newBlock->getArguments()); |
| 203 | b.inlineBlockBefore(source: block1, dest: newBlock, before: newBlock->begin(), |
| 204 | argValues: newBlock->getArguments()); |
| 205 | |
| 206 | ValueRange results = newSecondPloop.getResults(); |
| 207 | if (!results.empty()) { |
| 208 | b.setInsertionPointToEnd(newBlock); |
| 209 | |
| 210 | ValueRange reduceArgs1 = term1.getOperands(); |
| 211 | ValueRange reduceArgs2 = term2.getOperands(); |
| 212 | SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); |
| 213 | newReduceArgs.append(in_start: reduceArgs2.begin(), in_end: reduceArgs2.end()); |
| 214 | |
| 215 | auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); |
| 216 | |
| 217 | for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( |
| 218 | term1.getReductions(), term2.getReductions()))) { |
| 219 | Block &oldRedBlock = reg.front(); |
| 220 | Block &newRedBlock = newReduceOp.getReductions()[i].front(); |
| 221 | b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), |
| 222 | newRedBlock.getArguments()); |
| 223 | } |
| 224 | |
| 225 | firstPloop.replaceAllUsesWith(results.take_front(n: inits1.size())); |
| 226 | secondPloop.replaceAllUsesWith(results.take_back(n: inits2.size())); |
| 227 | } |
| 228 | term1->erase(); |
| 229 | term2->erase(); |
| 230 | firstPloop.erase(); |
| 231 | secondPloop.erase(); |
| 232 | secondPloop = newSecondPloop; |
| 233 | } |
| 234 | |
| 235 | void mlir::scf::naivelyFuseParallelOps( |
| 236 | Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { |
| 237 | OpBuilder b(region); |
| 238 | // Consider every single block and attempt to fuse adjacent loops. |
| 239 | SmallVector<SmallVector<ParallelOp>, 1> ploopChains; |
| 240 | for (auto &block : region) { |
| 241 | ploopChains.clear(); |
| 242 | ploopChains.push_back({}); |
| 243 | |
| 244 | // Not using `walk()` to traverse only top-level parallel loops and also |
| 245 | // make sure that there are no side-effecting ops between the parallel |
| 246 | // loops. |
| 247 | bool noSideEffects = true; |
| 248 | for (auto &op : block) { |
| 249 | if (auto ploop = dyn_cast<ParallelOp>(op)) { |
| 250 | if (noSideEffects) { |
| 251 | ploopChains.back().push_back(ploop); |
| 252 | } else { |
| 253 | ploopChains.push_back({ploop}); |
| 254 | noSideEffects = true; |
| 255 | } |
| 256 | continue; |
| 257 | } |
| 258 | // TODO: Handle region side effects properly. |
| 259 | noSideEffects &= isMemoryEffectFree(op: &op) && op.getNumRegions() == 0; |
| 260 | } |
| 261 | for (MutableArrayRef<ParallelOp> ploops : ploopChains) { |
| 262 | for (int i = 0, e = ploops.size(); i + 1 < e; ++i) |
| 263 | fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); |
| 264 | } |
| 265 | } |
| 266 | } |
| 267 | |
| 268 | namespace { |
| 269 | struct ParallelLoopFusion |
| 270 | : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { |
| 271 | void runOnOperation() override { |
| 272 | auto &AA = getAnalysis<AliasAnalysis>(); |
| 273 | |
| 274 | auto mayAlias = [&](Value val1, Value val2) -> bool { |
| 275 | return !AA.alias(val1, val2).isNo(); |
| 276 | }; |
| 277 | |
| 278 | getOperation()->walk([&](Operation *child) { |
| 279 | for (Region ®ion : child->getRegions()) |
| 280 | naivelyFuseParallelOps(region, mayAlias); |
| 281 | }); |
| 282 | } |
| 283 | }; |
| 284 | } // namespace |
| 285 | |
| 286 | std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { |
| 287 | return std::make_unique<ParallelLoopFusion>(); |
| 288 | } |
| 289 | |