1 | //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements loop tiling on parallel loops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
18 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
19 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
20 | |
21 | namespace mlir { |
22 | #define GEN_PASS_DEF_SCFPARALLELLOOPTILING |
23 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
24 | } // namespace mlir |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::scf; |
28 | |
29 | /// Tile a parallel loop of the form |
30 | /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) |
31 | /// step (%arg4, %arg5) |
32 | /// |
33 | /// into |
34 | /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) |
35 | /// step (%arg4*tileSize[0], |
36 | /// %arg5*tileSize[1]) |
37 | /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) |
38 | /// min(%arg5*tileSize[1], %arg3-%i1)) |
39 | /// step (%arg4, %arg5) |
40 | /// |
41 | /// or, when no-min-max-bounds is true, into |
42 | /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) |
43 | /// step (%arg4*tileSize[0], |
44 | /// %arg5*tileSize[1]) |
45 | /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0], |
46 | /// %arg5*tileSize[1]) |
47 | /// step (%arg4, %arg5) |
48 | /// %inbound = (%j0 * %arg4 + %i0 < %arg2) && |
49 | /// (%j1 * %arg5 + %i1 < %arg3) |
50 | /// scf.if (%inbound) |
51 | /// .... |
52 | /// |
53 | /// where the uses of %i0 and %i1 in the loop body are replaced by |
54 | /// %i0 + j0 and %i1 + %j1. |
55 | /// |
56 | /// The old loop is replaced with the new one. |
57 | std::pair<ParallelOp, ParallelOp> |
58 | mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, |
59 | bool noMinMaxBounds) { |
60 | OpBuilder b(op); |
61 | auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0); |
62 | SmallVector<Value, 2> tileSizeConstants; |
63 | tileSizeConstants.reserve(op.getUpperBound().size()); |
64 | for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) { |
65 | if (i < tileSizes.size()) |
66 | tileSizeConstants.push_back( |
67 | b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i])); |
68 | else |
69 | // Just pick 1 for the remaining dimensions. |
70 | tileSizeConstants.push_back( |
71 | b.create<arith::ConstantIndexOp>(op.getLoc(), 1)); |
72 | } |
73 | |
74 | // Create the outer loop with adjusted steps. |
75 | SmallVector<Value, 2> newSteps; |
76 | newSteps.reserve(op.getStep().size()); |
77 | for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) { |
78 | newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step), |
79 | std::get<1>(step))); |
80 | } |
81 | auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(), |
82 | op.getUpperBound(), newSteps); |
83 | b.setInsertionPointToStart(outerLoop.getBody()); |
84 | |
85 | // Compute min(size, dim - offset) to avoid out-of-bounds accesses. |
86 | auto minMap = AffineMap::get( |
87 | /*dimCount=*/3, /*symbolCount=*/0, |
88 | {getAffineDimExpr(/*position=*/0, context: b.getContext()), |
89 | getAffineDimExpr(/*position=*/1, context: b.getContext()) - |
90 | getAffineDimExpr(/*position=*/2, context: b.getContext())}, |
91 | b.getContext()); |
92 | |
93 | // Create the inner loop with adjusted bounds. |
94 | SmallVector<Value, 2> newBounds; |
95 | newBounds.reserve(op.getUpperBound().size()); |
96 | bool needInboundCheck = false; |
97 | for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] : |
98 | llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(), |
99 | outerLoop.getStep(), outerLoop.getInductionVars(), |
100 | op.getStep(), tileSizeConstants)) { |
101 | // Collect the statically known loop bounds |
102 | auto lowerBoundConstant = |
103 | dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp()); |
104 | auto upperBoundConstant = |
105 | dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp()); |
106 | auto stepConstant = |
107 | dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp()); |
108 | auto tileSize = |
109 | cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value(); |
110 | // If the loop bounds and the loop step are constant and if the number of |
111 | // loop iterations is an integer multiple of the tile size, we use a static |
112 | // bound for the inner loop. |
113 | if (lowerBoundConstant && upperBoundConstant && stepConstant) { |
114 | auto numIterations = llvm::divideCeil(upperBoundConstant.value() - |
115 | lowerBoundConstant.value(), |
116 | stepConstant.value()); |
117 | if (numIterations % tileSize == 0) { |
118 | newBounds.push_back(newStep); |
119 | continue; |
120 | } |
121 | } |
122 | |
123 | // For InboundCheck mode, just use the variable outer step |
124 | if (noMinMaxBounds) { |
125 | newBounds.push_back(newStep); |
126 | needInboundCheck = true; |
127 | continue; |
128 | } |
129 | |
130 | // Otherwise, we dynamically compute the bound for |
131 | // each iteration of the outer loop. |
132 | newBounds.push_back( |
133 | b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap, |
134 | ValueRange{newStep, upperBound, iv})); |
135 | } |
136 | auto innerLoop = b.create<ParallelOp>( |
137 | op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, |
138 | op.getStep()); |
139 | |
140 | if (noMinMaxBounds && needInboundCheck) { |
141 | b.setInsertionPointToStart(innerLoop.getBody()); |
142 | // Insert in-bound check |
143 | Value inbound = |
144 | b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1)); |
145 | for (auto [outerUpperBound, outerIV, innerIV, innerStep] : |
146 | llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(), |
147 | innerLoop.getInductionVars(), innerLoop.getStep())) { |
148 | // %in_bound = %in_bound && |
149 | // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) |
150 | Value index = b.create<arith::AddIOp>( |
151 | op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep), |
152 | outerIV); |
153 | Value dimInbound = b.create<arith::CmpIOp>( |
154 | op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); |
155 | inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound); |
156 | } |
157 | auto ifInbound = b.create<IfOp>(op.getLoc(), |
158 | /*resultTypes*/ ArrayRef<Type>{}, inbound, |
159 | /*hasElseRegion*/ false); |
160 | ifInbound.getThenRegion().takeBody(op.getRegion()); |
161 | Block &thenBlock = ifInbound.getThenRegion().front(); |
162 | // Replace the scf.reduce terminator with an scf.yield terminator. |
163 | Operation *reduceOp = thenBlock.getTerminator(); |
164 | b.setInsertionPointToEnd(&thenBlock); |
165 | b.create<scf::YieldOp>(reduceOp->getLoc()); |
166 | reduceOp->erase(); |
167 | b.setInsertionPointToStart(innerLoop.getBody()); |
168 | for (const auto &ivs : llvm::enumerate(llvm::zip( |
169 | innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { |
170 | auto newIndex = b.create<arith::AddIOp>( |
171 | op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); |
172 | thenBlock.getArgument(ivs.index()) |
173 | .replaceAllUsesExcept(newIndex, newIndex); |
174 | } |
175 | thenBlock.eraseArguments(start: 0, num: thenBlock.getNumArguments()); |
176 | } else { |
177 | innerLoop.getRegion().takeBody(op.getRegion()); |
178 | b.setInsertionPointToStart(innerLoop.getBody()); |
179 | for (auto ivs : llvm::zip(innerLoop.getInductionVars(), |
180 | outerLoop.getInductionVars())) { |
181 | Value innerIndex = std::get<0>(ivs); |
182 | auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs), |
183 | std::get<1>(ivs)); |
184 | innerIndex.replaceAllUsesExcept(newIndex, newIndex); |
185 | } |
186 | } |
187 | |
188 | op.erase(); |
189 | return std::make_pair(outerLoop, innerLoop); |
190 | } |
191 | |
192 | namespace { |
193 | struct ParallelLoopTiling |
194 | : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> { |
195 | ParallelLoopTiling() = default; |
196 | explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes, |
197 | bool noMinMaxBounds = false) { |
198 | this->tileSizes = tileSizes; |
199 | this->noMinMaxBounds = noMinMaxBounds; |
200 | } |
201 | |
202 | void runOnOperation() override { |
203 | for (auto tileSize : tileSizes) |
204 | if (tileSize == 0) { |
205 | mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()), |
206 | "tile size cannot be 0" ); |
207 | return signalPassFailure(); |
208 | } |
209 | auto *parentOp = getOperation(); |
210 | SmallVector<ParallelOp, 2> innermostPloops; |
211 | getInnermostParallelLoops(parentOp, innermostPloops); |
212 | for (ParallelOp ploop : innermostPloops) { |
213 | // FIXME: Add reduction support. |
214 | if (ploop.getNumReductions() == 0) |
215 | tileParallelLoop(ploop, tileSizes, noMinMaxBounds); |
216 | } |
217 | } |
218 | }; |
219 | } // namespace |
220 | |
221 | std::unique_ptr<Pass> |
222 | mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes, |
223 | bool noMinMaxBounds) { |
224 | return std::make_unique<ParallelLoopTiling>(args&: tileSizes, args&: noMinMaxBounds); |
225 | } |
226 | |