1//===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop tiling on parallel loops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19#include "mlir/Dialect/SCF/Utils/Utils.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_SCFPARALLELLOOPTILING
23#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::scf;
28
29/// Tile a parallel loop of the form
30/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
31/// step (%arg4, %arg5)
32///
33/// into
34/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
35/// step (%arg4*tileSize[0],
36/// %arg5*tileSize[1])
37/// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
38/// min(%arg5*tileSize[1], %arg3-%i1))
39/// step (%arg4, %arg5)
40///
41/// or, when no-min-max-bounds is true, into
42/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
43/// step (%arg4*tileSize[0],
44/// %arg5*tileSize[1])
45/// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
46/// %arg5*tileSize[1])
47/// step (%arg4, %arg5)
48/// %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
49/// (%j1 * %arg5 + %i1 < %arg3)
50/// scf.if (%inbound)
51/// ....
52///
53/// where the uses of %i0 and %i1 in the loop body are replaced by
54/// %i0 + j0 and %i1 + %j1.
55///
56/// The old loop is replaced with the new one.
57std::pair<ParallelOp, ParallelOp>
58mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
59 bool noMinMaxBounds) {
60 OpBuilder b(op);
61 auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
62 SmallVector<Value, 2> tileSizeConstants;
63 tileSizeConstants.reserve(op.getUpperBound().size());
64 for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
65 if (i < tileSizes.size())
66 tileSizeConstants.push_back(
67 b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i]));
68 else
69 // Just pick 1 for the remaining dimensions.
70 tileSizeConstants.push_back(
71 b.create<arith::ConstantIndexOp>(op.getLoc(), 1));
72 }
73
74 // Create the outer loop with adjusted steps.
75 SmallVector<Value, 2> newSteps;
76 newSteps.reserve(op.getStep().size());
77 for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
78 newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step),
79 std::get<1>(step)));
80 }
81 auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(),
82 op.getUpperBound(), newSteps);
83 b.setInsertionPointToStart(outerLoop.getBody());
84
85 // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
86 auto minMap = AffineMap::get(
87 /*dimCount=*/3, /*symbolCount=*/0,
88 {getAffineDimExpr(/*position=*/0, context: b.getContext()),
89 getAffineDimExpr(/*position=*/1, context: b.getContext()) -
90 getAffineDimExpr(/*position=*/2, context: b.getContext())},
91 b.getContext());
92
93 // Create the inner loop with adjusted bounds.
94 SmallVector<Value, 2> newBounds;
95 newBounds.reserve(op.getUpperBound().size());
96 bool needInboundCheck = false;
97 for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] :
98 llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(),
99 outerLoop.getStep(), outerLoop.getInductionVars(),
100 op.getStep(), tileSizeConstants)) {
101 // Collect the statically known loop bounds
102 auto lowerBoundConstant =
103 dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
104 auto upperBoundConstant =
105 dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
106 auto stepConstant =
107 dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
108 auto tileSize =
109 cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
110 // If the loop bounds and the loop step are constant and if the number of
111 // loop iterations is an integer multiple of the tile size, we use a static
112 // bound for the inner loop.
113 if (lowerBoundConstant && upperBoundConstant && stepConstant) {
114 auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
115 lowerBoundConstant.value(),
116 stepConstant.value());
117 if (numIterations % tileSize == 0) {
118 newBounds.push_back(newStep);
119 continue;
120 }
121 }
122
123 // For InboundCheck mode, just use the variable outer step
124 if (noMinMaxBounds) {
125 newBounds.push_back(newStep);
126 needInboundCheck = true;
127 continue;
128 }
129
130 // Otherwise, we dynamically compute the bound for
131 // each iteration of the outer loop.
132 newBounds.push_back(
133 b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
134 ValueRange{newStep, upperBound, iv}));
135 }
136 auto innerLoop = b.create<ParallelOp>(
137 op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
138 op.getStep());
139
140 if (noMinMaxBounds && needInboundCheck) {
141 b.setInsertionPointToStart(innerLoop.getBody());
142 // Insert in-bound check
143 Value inbound =
144 b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
145 for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
146 llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
147 innerLoop.getInductionVars(), innerLoop.getStep())) {
148 // %in_bound = %in_bound &&
149 // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
150 Value index = b.create<arith::AddIOp>(
151 op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep),
152 outerIV);
153 Value dimInbound = b.create<arith::CmpIOp>(
154 op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
155 inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound);
156 }
157 auto ifInbound = b.create<IfOp>(op.getLoc(),
158 /*resultTypes*/ ArrayRef<Type>{}, inbound,
159 /*hasElseRegion*/ false);
160 ifInbound.getThenRegion().takeBody(op.getRegion());
161 Block &thenBlock = ifInbound.getThenRegion().front();
162 // Replace the scf.reduce terminator with an scf.yield terminator.
163 Operation *reduceOp = thenBlock.getTerminator();
164 b.setInsertionPointToEnd(&thenBlock);
165 b.create<scf::YieldOp>(reduceOp->getLoc());
166 reduceOp->erase();
167 b.setInsertionPointToStart(innerLoop.getBody());
168 for (const auto &ivs : llvm::enumerate(llvm::zip(
169 innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
170 auto newIndex = b.create<arith::AddIOp>(
171 op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
172 thenBlock.getArgument(ivs.index())
173 .replaceAllUsesExcept(newIndex, newIndex);
174 }
175 thenBlock.eraseArguments(start: 0, num: thenBlock.getNumArguments());
176 } else {
177 innerLoop.getRegion().takeBody(op.getRegion());
178 b.setInsertionPointToStart(innerLoop.getBody());
179 for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
180 outerLoop.getInductionVars())) {
181 Value innerIndex = std::get<0>(ivs);
182 auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs),
183 std::get<1>(ivs));
184 innerIndex.replaceAllUsesExcept(newIndex, newIndex);
185 }
186 }
187
188 op.erase();
189 return std::make_pair(outerLoop, innerLoop);
190}
191
192namespace {
193struct ParallelLoopTiling
194 : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> {
195 ParallelLoopTiling() = default;
196 explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
197 bool noMinMaxBounds = false) {
198 this->tileSizes = tileSizes;
199 this->noMinMaxBounds = noMinMaxBounds;
200 }
201
202 void runOnOperation() override {
203 for (auto tileSize : tileSizes)
204 if (tileSize == 0) {
205 mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()),
206 "tile size cannot be 0");
207 return signalPassFailure();
208 }
209 auto *parentOp = getOperation();
210 SmallVector<ParallelOp, 2> innermostPloops;
211 getInnermostParallelLoops(parentOp, innermostPloops);
212 for (ParallelOp ploop : innermostPloops) {
213 // FIXME: Add reduction support.
214 if (ploop.getNumReductions() == 0)
215 tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
216 }
217 }
218};
219} // namespace
220
221std::unique_ptr<Pass>
222mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
223 bool noMinMaxBounds) {
224 return std::make_unique<ParallelLoopTiling>(args&: tileSizes, args&: noMinMaxBounds);
225}
226

source code of mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp