1 | //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements loop unrolling. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" |
16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
17 | #include "mlir/Dialect/Affine/LoopUtils.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/IR/AffineExpr.h" |
20 | #include "mlir/IR/AffineMap.h" |
21 | #include "mlir/IR/Builders.h" |
22 | #include "llvm/ADT/DenseMap.h" |
23 | #include "llvm/Support/CommandLine.h" |
24 | #include "llvm/Support/Debug.h" |
25 | #include <optional> |
26 | |
27 | namespace mlir { |
28 | namespace affine { |
29 | #define GEN_PASS_DEF_AFFINELOOPUNROLL |
30 | #include "mlir/Dialect/Affine/Passes.h.inc" |
31 | } // namespace affine |
32 | } // namespace mlir |
33 | |
34 | #define DEBUG_TYPE "affine-loop-unroll" |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::affine; |
38 | |
39 | namespace { |
40 | |
41 | // TODO: this is really a test pass and should be moved out of dialect |
42 | // transforms. |
43 | |
44 | /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a |
45 | /// full unroll threshold was specified, in which case, fully unrolls all loops |
46 | /// with trip count less than the specified threshold. The latter is for testing |
47 | /// purposes, especially for testing outer loop unrolling. |
48 | struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> { |
49 | // Callback to obtain unroll factors; if this has a callable target, takes |
50 | // precedence over command-line argument or passed argument. |
51 | const std::function<unsigned(AffineForOp)> getUnrollFactor; |
52 | |
53 | LoopUnroll() : getUnrollFactor(nullptr) {} |
54 | LoopUnroll(const LoopUnroll &other) |
55 | |
56 | = default; |
57 | explicit LoopUnroll( |
58 | std::optional<unsigned> unrollFactor = std::nullopt, |
59 | bool unrollUpToFactor = false, bool unrollFull = false, |
60 | const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) |
61 | : getUnrollFactor(getUnrollFactor) { |
62 | if (unrollFactor) |
63 | this->unrollFactor = *unrollFactor; |
64 | this->unrollUpToFactor = unrollUpToFactor; |
65 | this->unrollFull = unrollFull; |
66 | } |
67 | |
68 | void runOnOperation() override; |
69 | |
70 | /// Unroll this for op. Returns failure if nothing was done. |
71 | LogicalResult runOnAffineForOp(AffineForOp forOp); |
72 | }; |
73 | } // namespace |
74 | |
75 | /// Returns true if no other affine.for ops are nested within `op`. |
76 | static bool isInnermostAffineForOp(AffineForOp op) { |
77 | return !op.getBody() |
78 | ->walk([&](AffineForOp nestedForOp) { |
79 | return WalkResult::interrupt(); |
80 | }) |
81 | .wasInterrupted(); |
82 | } |
83 | |
84 | /// Gathers loops that have no affine.for's nested within. |
85 | static void gatherInnermostLoops(func::FuncOp f, |
86 | SmallVectorImpl<AffineForOp> &loops) { |
87 | f.walk([&](AffineForOp forOp) { |
88 | if (isInnermostAffineForOp(forOp)) |
89 | loops.push_back(forOp); |
90 | }); |
91 | } |
92 | |
93 | void LoopUnroll::runOnOperation() { |
94 | func::FuncOp func = getOperation(); |
95 | if (func.isExternal()) |
96 | return; |
97 | |
98 | if (unrollFull && unrollFullThreshold.hasValue()) { |
99 | // Store short loops as we walk. |
100 | SmallVector<AffineForOp, 4> loops; |
101 | |
102 | // Gathers all loops with trip count <= minTripCount. Do a post order walk |
103 | // so that loops are gathered from innermost to outermost (or else unrolling |
104 | // an outer one may delete gathered inner ones). |
105 | getOperation().walk([&](AffineForOp forOp) { |
106 | std::optional<uint64_t> tripCount = getConstantTripCount(forOp); |
107 | if (tripCount && *tripCount <= unrollFullThreshold) |
108 | loops.push_back(forOp); |
109 | }); |
110 | for (auto forOp : loops) |
111 | (void)loopUnrollFull(forOp); |
112 | return; |
113 | } |
114 | |
115 | // If the call back is provided, we will recurse until no loops are found. |
116 | SmallVector<AffineForOp, 4> loops; |
117 | for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { |
118 | loops.clear(); |
119 | gatherInnermostLoops(func, loops); |
120 | if (loops.empty()) |
121 | break; |
122 | bool unrolled = false; |
123 | for (auto forOp : loops) |
124 | unrolled |= succeeded(runOnAffineForOp(forOp)); |
125 | if (!unrolled) |
126 | // Break out if nothing was unrolled. |
127 | break; |
128 | } |
129 | } |
130 | |
131 | /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, |
132 | /// failure otherwise. The default unroll factor is 4. |
133 | LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { |
134 | // Use the function callback if one was provided. |
135 | if (getUnrollFactor) |
136 | return loopUnrollByFactor(forOp, getUnrollFactor(forOp), |
137 | /*annotateFn=*/nullptr, cleanUpUnroll); |
138 | // Unroll completely if full loop unroll was specified. |
139 | if (unrollFull) |
140 | return loopUnrollFull(forOp); |
141 | // Otherwise, unroll by the given unroll factor. |
142 | if (unrollUpToFactor) |
143 | return loopUnrollUpToFactor(forOp, unrollFactor); |
144 | return loopUnrollByFactor(forOp, unrollFactor, /*annotateFn=*/nullptr, |
145 | cleanUpUnroll); |
146 | } |
147 | |
148 | std::unique_ptr<OperationPass<func::FuncOp>> mlir::affine::createLoopUnrollPass( |
149 | int unrollFactor, bool unrollUpToFactor, bool unrollFull, |
150 | const std::function<unsigned(AffineForOp)> &getUnrollFactor) { |
151 | return std::make_unique<LoopUnroll>( |
152 | args: unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor), |
153 | args&: unrollUpToFactor, args&: unrollFull, args: getUnrollFactor); |
154 | } |
155 | |