| 1 | //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements loop unrolling. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/Passes.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" |
| 16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 17 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 19 | #include "mlir/IR/AffineExpr.h" |
| 20 | #include "mlir/IR/AffineMap.h" |
| 21 | #include "mlir/IR/Builders.h" |
| 22 | #include "llvm/ADT/DenseMap.h" |
| 23 | #include "llvm/Support/CommandLine.h" |
| 24 | #include "llvm/Support/Debug.h" |
| 25 | #include <optional> |
| 26 | |
| 27 | namespace mlir { |
| 28 | namespace affine { |
| 29 | #define GEN_PASS_DEF_AFFINELOOPUNROLL |
| 30 | #include "mlir/Dialect/Affine/Passes.h.inc" |
| 31 | } // namespace affine |
| 32 | } // namespace mlir |
| 33 | |
| 34 | #define DEBUG_TYPE "affine-loop-unroll" |
| 35 | |
| 36 | using namespace mlir; |
| 37 | using namespace mlir::affine; |
| 38 | |
| 39 | namespace { |
| 40 | |
| 41 | // TODO: this is really a test pass and should be moved out of dialect |
| 42 | // transforms. |
| 43 | |
| 44 | /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a |
| 45 | /// full unroll threshold was specified, in which case, fully unrolls all loops |
| 46 | /// with trip count less than the specified threshold. The latter is for testing |
| 47 | /// purposes, especially for testing outer loop unrolling. |
| 48 | struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> { |
| 49 | // Callback to obtain unroll factors; if this has a callable target, takes |
| 50 | // precedence over command-line argument or passed argument. |
| 51 | const std::function<unsigned(AffineForOp)> getUnrollFactor; |
| 52 | |
| 53 | LoopUnroll() : getUnrollFactor(nullptr) {} |
| 54 | LoopUnroll(const LoopUnroll &other) |
| 55 | |
| 56 | = default; |
| 57 | explicit LoopUnroll( |
| 58 | std::optional<unsigned> unrollFactor = std::nullopt, |
| 59 | bool unrollUpToFactor = false, bool unrollFull = false, |
| 60 | const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) |
| 61 | : getUnrollFactor(getUnrollFactor) { |
| 62 | if (unrollFactor) |
| 63 | this->unrollFactor = *unrollFactor; |
| 64 | this->unrollUpToFactor = unrollUpToFactor; |
| 65 | this->unrollFull = unrollFull; |
| 66 | } |
| 67 | |
| 68 | void runOnOperation() override; |
| 69 | |
| 70 | /// Unroll this for op. Returns failure if nothing was done. |
| 71 | LogicalResult runOnAffineForOp(AffineForOp forOp); |
| 72 | }; |
| 73 | } // namespace |
| 74 | |
| 75 | /// Returns true if no other affine.for ops are nested within `op`. |
| 76 | static bool isInnermostAffineForOp(AffineForOp op) { |
| 77 | return !op.getBody() |
| 78 | ->walk([&](AffineForOp nestedForOp) { |
| 79 | return WalkResult::interrupt(); |
| 80 | }) |
| 81 | .wasInterrupted(); |
| 82 | } |
| 83 | |
| 84 | /// Gathers loops that have no affine.for's nested within. |
| 85 | static void gatherInnermostLoops(FunctionOpInterface f, |
| 86 | SmallVectorImpl<AffineForOp> &loops) { |
| 87 | f.walk([&](AffineForOp forOp) { |
| 88 | if (isInnermostAffineForOp(forOp)) |
| 89 | loops.push_back(forOp); |
| 90 | }); |
| 91 | } |
| 92 | |
| 93 | void LoopUnroll::runOnOperation() { |
| 94 | FunctionOpInterface func = getOperation(); |
| 95 | if (func.isExternal()) |
| 96 | return; |
| 97 | |
| 98 | if (unrollFull && unrollFullThreshold.hasValue()) { |
| 99 | // Store short loops as we walk. |
| 100 | SmallVector<AffineForOp, 4> loops; |
| 101 | |
| 102 | // Gathers all loops with trip count <= minTripCount. Do a post order walk |
| 103 | // so that loops are gathered from innermost to outermost (or else |
| 104 | // unrolling an outer one may delete gathered inner ones). |
| 105 | getOperation().walk([&](AffineForOp forOp) { |
| 106 | std::optional<uint64_t> tripCount = getConstantTripCount(forOp); |
| 107 | if (tripCount && *tripCount <= unrollFullThreshold) |
| 108 | loops.push_back(forOp); |
| 109 | }); |
| 110 | for (auto forOp : loops) |
| 111 | (void)loopUnrollFull(forOp); |
| 112 | return; |
| 113 | } |
| 114 | |
| 115 | // If the call back is provided, we will recurse until no loops are found. |
| 116 | SmallVector<AffineForOp, 4> loops; |
| 117 | for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { |
| 118 | loops.clear(); |
| 119 | gatherInnermostLoops(func, loops); |
| 120 | if (loops.empty()) |
| 121 | break; |
| 122 | bool unrolled = false; |
| 123 | for (auto forOp : loops) |
| 124 | unrolled |= succeeded(runOnAffineForOp(forOp)); |
| 125 | if (!unrolled) |
| 126 | // Break out if nothing was unrolled. |
| 127 | break; |
| 128 | } |
| 129 | } |
| 130 | |
| 131 | /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, |
| 132 | /// failure otherwise. The default unroll factor is 4. |
| 133 | LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { |
| 134 | // Use the function callback if one was provided. |
| 135 | if (getUnrollFactor) |
| 136 | return loopUnrollByFactor(forOp, getUnrollFactor(forOp), |
| 137 | /*annotateFn=*/nullptr, cleanUpUnroll); |
| 138 | // Unroll completely if full loop unroll was specified. |
| 139 | if (unrollFull) |
| 140 | return loopUnrollFull(forOp); |
| 141 | // Otherwise, unroll by the given unroll factor. |
| 142 | if (unrollUpToFactor) |
| 143 | return loopUnrollUpToFactor(forOp, unrollFactor); |
| 144 | return loopUnrollByFactor(forOp, unrollFactor, /*annotateFn=*/nullptr, |
| 145 | cleanUpUnroll); |
| 146 | } |
| 147 | |
| 148 | std::unique_ptr<InterfacePass<FunctionOpInterface>> |
| 149 | mlir::affine::createLoopUnrollPass( |
| 150 | int unrollFactor, bool unrollUpToFactor, bool unrollFull, |
| 151 | const std::function<unsigned(AffineForOp)> &getUnrollFactor) { |
| 152 | return std::make_unique<LoopUnroll>( |
| 153 | args: unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor), |
| 154 | args&: unrollUpToFactor, args&: unrollFull, args: getUnrollFactor); |
| 155 | } |
| 156 | |