//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cstdint>

using namespace mlir;

#define DEBUG_TYPE "scf-utils"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, IR of the form
  //
  // ```mlir
  // scf.for .. {
  //   scf.for .. {
  //     scf.for .. {
  //       %value = ...
  //     }
  //   }
  // }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  // %0 = scf.for .. iter_args(%arg0 = %init) {
  //   %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //       %value = ...
  //       scf.yield %value
  //     }
  //     scf.yield %2
  //   }
  //   scf.yield %1
  // }
  // ```
  //
  // The innermost loop is handled using `replaceWithAdditionalYields`,
  // which works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively:
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
                                               innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    return llvm::to_vector(llvm::map_range(
        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
        [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
  return newLoopNest;
}

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types match the types of the operands yielded by
/// the single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` operations at the
/// start of `outlinedFuncBody` to allow simple canonicalizations. If `callOp`
/// is provided, it will be set to point to the operation that calls the
/// outlined function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(region's arguments, captured values).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicitly RAUW `outlinedValues`, only for uses within
  // `outlinedFunc`. Clone the `arith::ConstantIndexOp` at the start of
  // `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        IRMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}

LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
                                func::FuncOp *thenFn, StringRef thenFnName,
                                func::FuncOp *elseFn, StringRef elseFnName) {
  IRRewriter rewriter(b);
  Location loc = ifOp.getLoc();
  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
  if (thenFn && !ifOp.getThenRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getThenRegion(), thenFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *thenFn = *outlinedFuncOpOrFailure;
  }
  if (elseFn && !ifOp.getElseRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getElseRegion(), elseFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *elseFn = *outlinedFuncOpOrFailure;
  }
  return success();
}

bool mlir::getInnermostParallelLoops(Operation *rootOp,
                                     SmallVectorImpl<scf::ParallelOp> &result) {
  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
  bool rootEnclosesPloops = false;
  for (Region &region : rootOp->getRegions()) {
    for (Block &block : region.getBlocks()) {
      for (Operation &op : block) {
        bool enclosesPloops = getInnermostParallelLoops(&op, result);
        rootEnclosesPloops |= enclosesPloops;
        if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
          rootEnclosesPloops = true;

          // Collect parallel loop if it is an innermost one.
          if (!enclosesPloops)
            result.push_back(ploop);
        }
      }
    }
  }
  return rootEnclosesPloops;
}

// Build the IR that performs ceil division of a positive value by a constant:
//    ceildiv(a, B) = divis(a + (B - 1), B)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             int64_t divisor) {
  assert(divisor > 0 && "expected positive divisor");
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");

  Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
  Value divisorCst = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(dividend.getType(), divisor));
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
  return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}
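
// For illustration, a hedged sketch of the IR this emits for an index-typed
// dividend `%a` and divisor 4 (SSA names are hypothetical):
//
// ```mlir
// %c3 = arith.constant 3 : index
// %c4 = arith.constant 4 : index
// %sum = arith.addi %a, %c3 : index
// %res = arith.divui %sum, %c4 : index
// ```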

// Build the IR that performs ceil division of a positive value by another
// positive value:
//    ceildiv(a, b) = divis(a + (b - 1), b)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             Value divisor) {
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");
  Value cstOne = builder.create<arith::ConstantOp>(
      loc, builder.getOneAttr(dividend.getType()));
  Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
  return builder.create<arith::DivUIOp>(loc, sum, divisor);
}

/// Returns the trip count of `forOp` if its lower bound, upper bound and step
/// are constants, or std::nullopt otherwise. The trip count is computed as
/// ceilDiv(upperBound - lowerBound, step).
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
  if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
    return {};

  // Constant loop bounds computation.
  int64_t lbCst = lbCstOp.value();
  int64_t ubCst = ubCstOp.value();
  int64_t stepCst = stepCstOp.value();
  assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
         "expected positive loop bounds and step");
  return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
}
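
// For example, a loop `scf.for %i = %c0 to %c10 step %c3` has trip count
// ceilDiv(10 - 0, 3) = 4, covering the iterations 0, 3, 6 and 9.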

/// Generates unrolled copies of the body block 'loopBodyBlock' of an
/// scf::ForOp, with associated induction variable 'forOpIV', by
/// 'unrollFactor', calling 'ivRemapFn' to remap 'forOpIV' for each unrolled
/// body. If specified, annotates the ops in each unrolled iteration using
/// 'annotateFn'.
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Builder to insert unrolled bodies just before the terminator of the body
  // of 'forOp'.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
  if (!annotateFn)
    annotateFn = defaultAnnotateFn;

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional
  // copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    IRMapping operandMap;

    // Prepare operand map.
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
      Operation *clonedOp = builder.clone(*it, operandMap);
      annotateFn(i, clonedOp, builder);
    }

    // Update yielded values.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
  }

  // Make sure we annotate the ops in the original body. We do this last so
  // that any annotations are not copied into the cloned ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}

/// Unrolls 'forOp' by 'unrollFactor', and returns the unrolled main loop and
/// the epilogue loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{forOp, std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
    if (unrollFactor == 1) {
      if (*constTripCount == 1 &&
          failed(forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{forOp, std::nullopt};
    }

    int64_t tripCountEvenMultiple =
        *constTripCount - (*constTripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
          loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
                                            upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantOp>(
                             loc, boundsBuilder.getIntegerAttr(
                                      step.getType(), stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
        loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor).
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step.
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean-up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getInitArgs().size(), results);
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step,
            b.create<arith::ConstantOp>(loc,
                                        b.getIntegerAttr(iv.getType(), i)));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}
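
// As a hedged illustration (hypothetical SSA names, constants folded for
// brevity), unrolling the following loop by a factor of 3
//
// ```mlir
// scf.for %i = %c0 to %c11 step %c1 {
//   "op"(%i) : (index) -> ()
// }
// ```
//
// produces a main loop whose step is scaled by 3, plus an epilogue loop for
// the 11 mod 3 = 2 remaining iterations:
//
// ```mlir
// scf.for %i = %c0 to %c9 step %c3 {
//   "op"(%i) : (index) -> ()
//   %i1 = arith.addi %i, %c1 : index
//   "op"(%i1) : (index) -> ()
//   %i2 = arith.addi %i, %c2 : index
//   "op"(%i2) : (index) -> ()
// }
// scf.for %i = %c9 to %c11 step %c1 {
//   "op"(%i) : (index) -> ()
// }
// ```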

/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
  IRRewriter rewriter(forOp.getContext());
  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
  if (!mayBeConstantTripCount.has_value())
    return failure();
  uint64_t tripCount = *mayBeConstantTripCount;
  if (tripCount == 0)
    return success();
  if (tripCount == 1)
    return forOp.promoteIfSingleIteration(rewriter);
  return loopUnrollByFactor(forOp, tripCount);
}

/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}

/// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG("failed to unroll and jam: inner bounds are not invariant");
    return failure();
  }

  // Currently, loops with results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG("failed to unroll and jam: unsupported loop with results");
    return failure();
  }

  // Currently, only a constant trip count that is a multiple of the unroll
  // factor is supported.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG("failed to unroll and jam: trip count could not be determined");
    return failure();
  }
  if (unrollJamFactor > *tripCount) {
    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
         "count");
    unrollJamFactor = *tripCount;
  } else if (*tripCount % unrollJamFactor != 0) {
    LDBG("failed to unroll and jam: unsupported trip count that is not a "
         "multiple of unroll jam factor");
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer<scf::ForOp> jbg;
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(newIterArgs[j],
                               newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(newResults[j],
                               newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of the loop being unroll-jammed by the unroll-jam factor.
  rewriter.setInsertionPoint(forOp);
  int64_t step = forOp.getConstantStep()->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), forOp.getStep(),
      rewriter.createOrFold<arith::ConstantOp>(
          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // the sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            forOp.getLoc(), builder.getIndexAttr(step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
        operandMaps[i - 1].map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such a mapped value exists.
        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
                            operandMaps[i - 1].lookupOrDefault(
                                newForOp.getOperand(numControlOperands + j)));
        yieldOp.setOperand(
            i * oldNumYieldOperands + j,
            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}
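
// As a hedged illustration (hypothetical SSA names, constants folded for
// brevity), unrolling and jamming the following nest by a factor of 2
//
// ```mlir
// scf.for %i = %c0 to %c8 step %c1 {
//   scf.for %j = %c0 to %c4 step %c1 {
//     "op"(%i, %j) : (index, index) -> ()
//   }
// }
// ```
//
// scales the outer step by 2 and jams the duplicated body into the inner
// loop:
//
// ```mlir
// scf.for %i = %c0 to %c8 step %c2 {
//   scf.for %j = %c0 to %c4 step %c1 {
//     "op"(%i, %j) : (index, index) -> ()
//     %i1 = arith.addi %i, %c1 : index
//     "op"(%i1, %j) : (index, index) -> ()
//   }
// }
// ```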

Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
                                           OpFoldResult lb, OpFoldResult ub,
                                           OpFoldResult step) {
  Range normalizedLoopBounds;
  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  AffineExpr e = (s1 - s0).ceilDiv(s2);
  normalizedLoopBounds.size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
  return normalizedLoopBounds;
}

Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                     OpFoldResult lb, OpFoldResult ub,
                                     OpFoldResult step) {
  if (getType(lb).isIndex()) {
    return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
  }
  // For non-index types, generate `arith` instructions.
  // Check if the loop is already known to have a constant zero lower bound or
  // a constant one step.
  bool isZeroBased = false;
  if (auto lbCst = getConstantIntValue(lb))
    isZeroBased = lbCst.value() == 0;

  bool isStepOne = false;
  if (auto stepCst = getConstantIntValue(step))
    isStepOne = stepCst.value() == 1;

  Type rangeType = getType(lb);
  assert(rangeType == getType(ub) && rangeType == getType(step) &&
         "expected matching types");

  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
  // assuming the step is strictly positive. Update the bounds and the step
  // of the loop to go from 0 to the number of iterations, if necessary.
  if (isZeroBased && isStepOne)
    return {lb, ub, step};

  OpFoldResult diff = ub;
  if (!isZeroBased) {
    diff = rewriter.createOrFold<arith::SubIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
        getValueOrCreateConstantIntOp(rewriter, loc, lb));
  }
  OpFoldResult newUpperBound = diff;
  if (!isStepOne) {
    newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
        getValueOrCreateConstantIntOp(rewriter, loc, step));
  }

  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
  OpFoldResult newStep = rewriter.getOneAttr(rangeType);

  return {newLowerBound, newUpperBound, newStep};
}
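
// As a hedged example for the non-index path (hypothetical SSA names): an
// i32 loop over [%lb = 4, %ub = 20) with step 3 is normalized to the range
// [0, 6) with step 1, where 6 = ceildiv(20 - 4, 3) is computed as
//
// ```mlir
// %diff = arith.subi %ub, %lb : i32
// %newUb = arith.ceildivsi %diff, %step : i32
// ```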

static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  AffineExpr d0, s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr e = d0 * s1 + s0;
  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
      rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedIvVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
  SmallPtrSet<Operation *, 1> preservedUses;
  // If an `affine.apply` operation is generated for denormalization, the use
  // of `origLb` in that op must not be replaced. Such ops are not generated
  // when `origLb == 0` and `origStep == 1`.
  if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
      preservedUses.insert(preservedUse);
    }
  }
  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
}

void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
                                        Value normalizedIv, OpFoldResult origLb,
                                        OpFoldResult origStep) {
  if (getType(origLb).isIndex()) {
    return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
                                                    origLb, origStep);
  }
  Value denormalizedIv;
  SmallPtrSet<Operation *, 2> preserve;
  bool isStepOne = isOneInteger(origStep);
  bool isZeroBased = isZeroInteger(origLb);

  Value scaled = normalizedIv;
  if (!isStepOne) {
    Value origStepValue =
        getValueOrCreateConstantIntOp(rewriter, loc, origStep);
    scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
    preserve.insert(scaled.getDefiningOp());
  }
  denormalizedIv = scaled;
  if (!isZeroBased) {
    Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
    denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
    preserve.insert(denormalizedIv.getDefiningOp());
  }

  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}
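
// Continuing the example above (hedged, hypothetical SSA names): uses of the
// normalized i32 induction variable %iv over [0, 6) are rewritten in terms
// of the original range as iv * step + lb:
//
// ```mlir
// %scaled = arith.muli %iv, %step : i32
// %denorm = arith.addi %scaled, %lb : i32
// ```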

static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<OpFoldResult> values) {
  assert(!values.empty() && "unexpected empty array");
  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  AffineExpr mul = s0 * s1;
  OpFoldResult products = rewriter.getIndexAttr(1);
  for (auto v : values) {
    products = affine::makeComposedFoldedAffineApply(
        rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
  }
  return products;
}

/// Helper function to multiply a sequence of values.
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
                                       ArrayRef<Value> values) {
  assert(!values.empty() && "unexpected empty list");
  if (getType(values.front()).isIndex()) {
    SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
    OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
    return getValueOrCreateConstantIndexOp(rewriter, loc, product);
  }
  std::optional<Value> productOf;
  for (auto v : values) {
    auto vOne = getConstantIntValue(v);
    if (vOne && vOne.value() == 1)
      continue;
    if (productOf)
      productOf =
          rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
    else
      productOf = v;
  }
  if (!productOf) {
    productOf = rewriter
                    .create<arith::ConstantOp>(
                        loc, rewriter.getOneAttr(getType(values.front())))
                    .getResult();
  }
  return productOf.value();
}

/// For each original loop, the value of the induction variable can be
/// obtained by dividing the induction variable of the linearized loop by the
/// total number of iterations of the loops nested in it, modulo the number of
/// iterations in this loop (this removes the values related to the outer
/// loops):
///   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
/// Compute these iteratively from the innermost loop by creating a "running
/// quotient" of division by the range.
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                             Value linearizedIv, ArrayRef<Value> ubs) {

  if (linearizedIv.getType().isIndex()) {
    Operation *delinearizedOp =
        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
                                                          ubs);
    auto resultVals = llvm::map_to_vector(
        delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
  }

  SmallVector<Value> delinearizedIvs(ubs.size());
  SmallPtrSet<Operation *, 2> preservedUsers;

  llvm::BitVector isUbOne(ubs.size());
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    auto ubCst = getConstantIntValue(ub);
    if (ubCst && ubCst.value() == 1)
      isUbOne.set(index);
  }

  // Prune the leading ubs that are all ones.
  unsigned numLeadingOneUbs = 0;
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    if (!isUbOne.test(index)) {
      break;
    }
    delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(ub.getType()));
    numLeadingOneUbs++;
  }

  Value previous = linearizedIv;
  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
    if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
      previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
      preservedUsers.insert(previous.getDefiningOp());
    }
    Value iv = previous;
    if (i != e - 1) {
      if (!isUbOne.test(idx)) {
        iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
        preservedUsers.insert(iv.getDefiningOp());
      } else {
        iv = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(ubs[idx].getType()));
      }
    }
    delinearizedIvs[idx] = iv;
  }
  return {delinearizedIvs, preservedUsers};
}
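
// As a worked example for the non-index path: with upper bounds (3, 4, 5)
// and linearized IV 37, the emitted rem/div chain computes
//   iv2 = 37 mod 5 = 2, quotient 37 div 5 = 7,
//   iv1 =  7 mod 4 = 3, quotient  7 div 4 = 1,
//   iv0 =  1,
// which indeed satisfies 1 * (4 * 5) + 3 * 5 + 2 = 37.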

LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
                                  MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return failure();

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
  // allows the following code to assume upperBound is the number of
  // iterations.
  for (auto loop : loops) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(outermost);
    Value lb = loop.getLowerBound();
    Value ub = loop.getUpperBound();
    Value step = loop.getStep();
    auto newLoopRange =
        emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

    rewriter.modifyOpInPlace(loop, [&]() {
      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.offset));
      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.size));
      loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                 newLoopRange.stride));
    });
    rewriter.setInsertionPointToStart(innermost.getBody());
    denormalizeInductionVariable(rewriter, loop.getLoc(),
                                 loop.getInductionVar(), lb, step);
  }

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outermost);
  Location loc = outermost.getLoc();
  SmallVector<Value> upperBounds = llvm::map_to_vector(
      loops, [](auto loop) { return loop.getUpperBound(); });
  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
  outermost.setUpperBound(upperBound);

  rewriter.setInsertionPointToStart(innermost.getBody());
  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
      rewriter, loc, outermost.getInductionVar(), upperBounds);
  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
                                preservedUsers);

  for (int i = loops.size() - 1; i > 0; --i) {
    auto outerLoop = loops[i - 1];
    auto innerLoop = loops[i];

    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
    auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
    for (Value &yieldedVal : yieldedVals) {
      // The yielded value may be an iteration argument of the inner loop
      // which is about to be inlined.
      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
      if (iter != innerLoop.getRegionIterArgs().end()) {
        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
        // `outerLoop` iter args are identical to the `innerLoop` init args.
        assert(iterArgIndex < innerLoop.getInitArgs().size());
        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
      }
    }
    rewriter.eraseOp(innerTerminator);

    SmallVector<Value> innerBlockArgs;
    innerBlockArgs.push_back(delinearizeIvs[i]);
    llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
    rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
                               Block::iterator(innerLoop), innerBlockArgs);
    rewriter.replaceOp(innerLoop, yieldedVals);
  }
  return success();
}
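
// As a hedged illustration (hypothetical SSA names), coalescing the
// already-normalized index-typed nest
//
// ```mlir
// scf.for %i = %c0 to %c4 step %c1 {
//   scf.for %j = %c0 to %c8 step %c1 {
//     "op"(%i, %j) : (index, index) -> ()
//   }
// }
// ```
//
// produces a single loop over the product of the iteration counts whose IV
// is delinearized back into the original IVs:
//
// ```mlir
// scf.for %iv = %c0 to %c32 step %c1 {
//   %ij:2 = affine.delinearize_index %iv into (%c4, %c8) : index, index
//   "op"(%ij#0, %ij#1) : (index, index) -> ()
// }
// ```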

LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  if (loops.empty()) {
    return failure();
  }
  IRRewriter rewriter(loops.front().getContext());
  return coalesceLoops(rewriter, loops);
}

LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
  LogicalResult result(failure());
  SmallVector<scf::ForOp> loops;
  getPerfectlyNestedLoops(loops, op);

  // Look for a band of loops that can be coalesced, i.e. perfectly nested
  // loops with bounds defined above some loop.

  // 1. For each loop, find above which parent loop its bounds operands are
  // defined.
  SmallVector<unsigned> operandsDefinedAbove(loops.size());
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    operandsDefinedAbove[i] = i;
    for (unsigned j = 0; j < i; ++j) {
      SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
                                           loops[i].getUpperBound(),
                                           loops[i].getStep()};
      if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
        operandsDefinedAbove[i] = j;
        break;
      }
    }
  }

  // 2. For each inner loop, check that the iter_args of the immediately outer
  // loop are the init args of the inner loop, and that the results of the
  // inner loop are the operands of the yield of the immediately outer loop.
  // Keep track of where the chain starts for each loop.
  SmallVector<unsigned> iterArgChainStart(loops.size());
  iterArgChainStart[0] = 0;
  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
    // By default set the start of the chain to itself.
    iterArgChainStart[i] = i;
    auto outerloop = loops[i - 1];
    auto innerLoop = loops[i];
    if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
      continue;
    }
    if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
      continue;
    }
    auto outerloopTerminator = outerloop.getBody()->getTerminator();
    if (!llvm::equal(outerloopTerminator->getOperands(),
                     innerLoop.getResults())) {
      continue;
    }
    iterArgChainStart[i] = iterArgChainStart[i - 1];
  }

  // 3. Identify bands of loops such that the operands of all of them are
  // defined above the first loop in the band. Traverse the nest bottom-up
  // so that modifications don't invalidate the inner loops.
  for (unsigned end = loops.size(); end > 0; --end) {
    unsigned start = 0;
    for (; start < end - 1; ++start) {
      auto maxPos =
          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
                            std::next(operandsDefinedAbove.begin(), end));
      if (maxPos > start)
        continue;
      if (iterArgChainStart[end - 1] > start)
        continue;
      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
      if (succeeded(coalesceLoops(band)))
        result = success();
      break;
    }
    // If a band was found and transformed, keep looking at the loops above
    // the outermost transformed loop.
    if (start != end - 1)
      end = start + 1;
  }
  return result;
}

void mlir::collapseParallelLoops(
    RewriterBase &rewriter, scf::ParallelOp loops,
    ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    llvm::sort(dims);

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder::InsertionGuard g2(rewriter);
    rewriter.setInsertionPoint(loops);
    Value lb = loops.getLowerBound()[i];
    Value ub = loops.getUpperBound()[i];
    Value step = loops.getStep()[i];
    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
        rewriter, loops.getLoc(), newLoopRange.size));

    rewriter.setInsertionPointToStart(loops.getBody());
    denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
                                 step);
  }

  // Combine iteration spaces.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (auto &sortedDimension : sortedDimensions) {
    Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimension) {
      newUpperBound = rewriter.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create a new ParallelLoop with conversions to the original induction
  // values. The loop below uses divisions to extract, from each new induction
  // value, the range that corresponds to each original induction value; the
  // remainders then determine, within that range, which iteration of the
  // original induction value this represents. This is a normalized value
  // that has already been un-normalized by the logic above.
  auto newPloop = rewriter.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration.
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}
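
// As a hedged illustration (hypothetical SSA names), collapsing both
// dimensions of
//
// ```mlir
// scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c8) step (%c1, %c1) {
//   "op"(%i, %j) : (index, index) -> ()
// }
// ```
//
// with combinedDimensions = {{0, 1}} yields a 1-D loop whose IV is split
// back into the original IVs with remainder/division:
//
// ```mlir
// scf.parallel (%iv) = (%c0) to (%c32) step (%c1) {
//   %j = arith.remsi %iv, %c8 : index
//   %i = arith.divsi %iv, %c8 : index
//   "op"(%i, %j) : (index, index) -> ()
// }
// ```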

// Hoist the ops within `outer` that appear before `inner`.
// Such ops include the ops that have been introduced by parametric tiling.
// Ops that come from triangular loops (i.e. that belong to the program slice
// rooted at `outer`) and ops that have side effects cannot be hoisted.
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
  SetVector<Operation *> forwardSlice;
  ForwardSliceOptions options;
  options.filter = [&inner](Operation *op) {
    return op != inner.getOperation();
  };
  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
  LogicalResult status = success();
  SmallVector<Operation *, 8> toHoist;
  for (auto &op : outer.getBody()->without_terminator()) {
    // Stop when encountering the inner loop.
    if (&op == inner.getOperation())
      break;
    // Skip over non-hoistable ops.
    if (forwardSlice.count(&op) > 0) {
      status = failure();
      continue;
    }
    // Skip intermediate scf::ForOp, these are not considered a failure.
    if (isa<scf::ForOp>(op))
      continue;
    // Skip other ops with regions.
    if (op.getNumRegions() > 0) {
      status = failure();
      continue;
    }
    // Skip if op has side effects.
    // TODO: loads to immutable memory regions are ok.
    if (!isMemoryEffectFree(&op)) {
      status = failure();
      continue;
    }
    toHoist.push_back(&op);
  }
  auto *outerForOp = outer.getOperation();
  for (auto *op : toHoist)
    op->moveBefore(outerForOp);
  return status;
}

// Traverse the interTile and intraTile loops and try to hoist ops such that
// bands of perfectly nested loops are isolated.
// Return failure if either perfect interTile or perfect intraTile bands cannot
// be formed.
static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
  LogicalResult status = success();
  const Loops &interTile = tileLoops.first;
  const Loops &intraTile = tileLoops.second;
  auto size = interTile.size();
  assert(size == intraTile.size());
  if (size <= 1)
    return success();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
                               : failure();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
                               : failure();
  return status;
}

/// Collect perfectly nested loops starting from `rootForOp`. Loops are
/// perfectly nested if each loop is the first and only non-terminator
/// operation in the parent loop. Collect at most `maxLoops` loops and append
/// them to `forOps`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  for (unsigned i = 0; i < maxLoops; ++i) {
    forOps.push_back(rootForOp);
    Block &body = rootForOp.getRegion().front();
    if (body.begin() != std::prev(body.end(), 2))
      return;

    rootForOp = dyn_cast<T>(&body.front());
    if (!rootForOp)
      return;
  }
}

static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done.
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value ub =
        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
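
// As a hedged illustration (hypothetical SSA names, constants folded for
// brevity), strip-mining `scf.for %i = %c0 to %c10 step %c1` by a factor of
// 4 with the loop itself as the single target gives
//
// ```mlir
// scf.for %i = %c0 to %c10 step %c4 {
//   %next = arith.addi %i, %c4 : index
//   %ubi = arith.minsi %c10, %next : index
//   scf.for %ii = %i to %ubi step %c1 {
//     // original body, with %i replaced by %ii
//   }
// }
// ```
//
// The min keeps the last inner loop from running past the original upper
// bound (its final trip covers just iterations 8 and 9).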

// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
// Returns the new for operation, nested immediately under `target`.
template <typename SizeType>
static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
                                scf::ForOp target) {
  // TODO: Use cheap structural assertions that targets are nested under
  // forOp and that targets are not nested under each other when DominanceInfo
  // exposes the capability. It seems overkill to construct a whole function
  // dominance tree at this point.
  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
  assert(res.size() == 1 && "Expected 1 inner forOp");
  return res[0];
}

SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
                                 ArrayRef<Value> sizes,
                                 ArrayRef<scf::ForOp> targets) {
  SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
  SmallVector<scf::ForOp, 8> currentTargets(targets);
  for (auto it : llvm::zip(forOps, sizes)) {
    auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
    res.push_back(step);
    currentTargets = step;
  }
  return res;
}

Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                 scf::ForOp target) {
  SmallVector<scf::ForOp, 8> res;
  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
    res.push_back(llvm::getSingleElement(loops));
  return res;
}

Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
  // Collect perfectly nested loops. If more size values are provided than
  // nested loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  return ::tile(forOps, sizes, forOps.back());
}

void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}

TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops. If more size values are provided than
  // nested loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  // Compute the tile sizes such that the i-th outer loop executes size[i]
  // iterations. Given that the loop currently executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
                                               forOp.getLowerBound());
    Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, numIterations, sizes[i]);
    tileSizes.push_back(iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, tileSizes, forOps.back());
  TileLoops tileLoops = std::make_pair(forOps, intraTile);

  // TODO: for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  (void)tryIsolateBands(tileLoops);

  return tileLoops;
}

scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused shared_outs.
  SmallVector<Value> fusedOuts;
  llvm::append_range(fusedOuts, target.getOutputs());
  llvm::append_range(fusedOuts, source.getOutputs());

  // Create a new scf.forall op after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
      source.getMixedStep(), fusedOuts, source.getMapping());

  // Map control operands.
  IRMapping mapping;
  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());

  // Map shared outs.
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Fuse the old terminator in_parallel ops into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  for (Operation &op : targetTerm.getYieldingOps())
    rewriter.clone(op, mapping);
  for (Operation &op : sourceTerm.getYieldingOps())
    rewriter.clone(op, mapping);

  // Replace the old loops by substituting their uses with the results of the
  // fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}

scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                scf::ForOp source,
                                                RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused init_args, with target's init_args before source's init_args.
  SmallVector<Value> fusedInitArgs;
  llvm::append_range(fusedInitArgs, target.getInitArgs());
  llvm::append_range(fusedInitArgs, source.getInitArgs());

  // Create a new scf.for op after the source loop (with an argument-less
  // scf.yield terminator only in case its init_args is empty).
  rewriter.setInsertionPointAfter(source);
  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
      source.getStep(), fusedInitArgs);

  // Map original induction variables and operands to those of the fused loop.
  IRMapping mapping;
  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Merge target's body into the new (fused) for loop and then source's body.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Build fused yield results by appropriately mapping original yield
  // operands.
  SmallVector<Value> yieldResults;
  for (Value operand : target.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  for (Value operand : source.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  if (!yieldResults.empty())
    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);

  // Replace the old loops by substituting their uses with the results of the
  // fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}
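
// As a hedged illustration (hypothetical SSA names), fusing the two
// independent sibling loops
//
// ```mlir
// %r0 = scf.for %i = %c0 to %c8 step %c1 iter_args(%a = %t0) -> (tensor<8xf32>) {
//   %u = "target_op"(%i, %a) : (index, tensor<8xf32>) -> tensor<8xf32>
//   scf.yield %u : tensor<8xf32>
// }
// %r1 = scf.for %i = %c0 to %c8 step %c1 iter_args(%b = %t1) -> (tensor<8xf32>) {
//   %v = "source_op"(%i, %b) : (index, tensor<8xf32>) -> tensor<8xf32>
//   scf.yield %v : tensor<8xf32>
// }
// ```
//
// produces a single loop carrying both iter_args:
//
// ```mlir
// %r:2 = scf.for %i = %c0 to %c8 step %c1
//     iter_args(%a = %t0, %b = %t1) -> (tensor<8xf32>, tensor<8xf32>) {
//   %u = "target_op"(%i, %a) : (index, tensor<8xf32>) -> tensor<8xf32>
//   %v = "source_op"(%i, %b) : (index, tensor<8xf32>) -> tensor<8xf32>
//   scf.yield %u, %v : tensor<8xf32>, tensor<8xf32>
// }
// ```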

FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
                                                 scf::ForallOp forallOp) {
  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();

  if (forallOp.isNormalized())
    return forallOp;

  OpBuilder::InsertionGuard g(rewriter);
  auto loc = forallOp.getLoc();
  rewriter.setInsertionPoint(forallOp);
  SmallVector<OpFoldResult> newUbs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
    Range normalizedLoopParams =
        emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    newUbs.push_back(normalizedLoopParams.size);
  }
  (void)foldDynamicIndexList(newUbs);

  // Use the normalized builder since the lower bounds are always 0 and the
  // steps are always 1.
  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
      loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
                              normalizedForallOp.getBodyRegion(),
                              normalizedForallOp.getBodyRegion().begin());
  // Remove the original empty block in the new loop.
  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());

  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
  // Update the users of the original loop variables.
  for (auto [idx, iv] :
       llvm::enumerate(normalizedForallOp.getInductionVars())) {
    auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
    auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
    denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
  }

  rewriter.replaceOp(forallOp, normalizedForallOp);
  return normalizedForallOp;
}