1//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous loop transformation routines.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Utils/Utils.h"
14#include "mlir/Analysis/SliceAnalysis.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24#include "mlir/Transforms/RegionUtils.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/MathExtras.h"
29#include <cstdint>
30
31using namespace mlir;
32
33#define DEBUG_TYPE "scf-utils"
34#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
36
/// Replace each loop of the perfectly nested `loopNest` with a new loop that
/// additionally yields the values produced by `newYieldValuesFn`, threading
/// `newIterOperands` through the nest as iter_args. Returns the new loops,
/// outermost first. If `replaceIterOperandsUsesInLoop` is set, uses of the
/// new iter operands inside each loop body are replaced by the corresponding
/// region iter args.
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, the following
  // pattern needs to be created, i.e. starting with
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  //  %0 = scf.for .. iter_args(%arg0 = %init) {
  //    %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //      %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //        %value = ...
  //        scf.yield %value
  //      }
  //      scf.yield %2
  //    }
  //    scf.yield %1
  //  }
  // ```
  //
  // The inner most loop is handled using the `replaceWithAdditionalYields`
  // that works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(Val: *loopNest.back().replaceWithAdditionalYields(
            rewriter, newInitOperands: newIterOperands, replaceInitOperandUsesInLoop: replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  // Yield-values callback for the outermost loop: first rewrite the rest of
  // the nest (recursively), then yield the trailing results of the new
  // immediately-nested loop.
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest: loopNest.drop_front(),
                                               newIterOperands: innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    // The new values to yield are the last `innerNewBBArgs.size()` results of
    // the rewritten inner loop.
    return llvm::to_vector(Range: llvm::map_range(
        C: newLoopNest.front().getResults().take_back(n: innerNewBBArgs.size()),
        F: [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(Val: *loopNest.front().replaceWithAdditionalYields(
          rewriter, newInitOperands: newIterOperands, replaceInitOperandUsesInLoop: replaceIterOperandsUsesInLoop, newYieldValuesFn: fn));
  // Prepend the (new) outermost loop so the result is ordered outermost-first.
  newLoopNest.insert(I: newLoopNest.begin(), Elt: outerMostLoop);
  return newLoopNest;
}
106
107/// Outline a region with a single block into a new FuncOp.
108/// Assumes the FuncOp result types is the type of the yielded operands of the
109/// single block. This constraint makes it easy to determine the result.
110/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
112/// provided, it will be set to point to the operation that calls the outlined
113/// function.
114// TODO: support more than single-block regions.
115// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  // Only single-block regions are supported (see TODO above).
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  // Values defined above the region must be passed into the outlined function
  // as extra (trailing) arguments.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(regions: region, values&: captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(Elt: arg.getType());
    outlinedFuncArgLocs.push_back(Elt: arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(Elt: value.getType());
    outlinedFuncArgLocs.push_back(Elt: value.getLoc());
  }
  // The outlined function returns exactly what the original terminator was
  // forwarding (see the doc comment above the function).
  FunctionType outlinedFuncType =
      FunctionType::get(context: rewriter.getContext(), inputs: outlinedFuncArgTypes,
                        results: originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(location: loc, args&: funcName, args&: outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    // The leading func arguments stand in for the original block arguments.
    rewriter.mergeBlocks(
        source: originalBlock, dest: outlinedFuncBody,
        argValues: outlinedFuncBlockArgs.take_front(N: numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(location: loc, args: originalTerminator->getResultTypes(),
                                    args: originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      parent: &region, insertPt: region.begin(),
      argTypes: TypeRange{outlinedFuncArgTypes}.take_front(n: numOriginalBlockArguments),
      locs: ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(N: numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // The call forwards the new block's arguments followed by the captures,
    // matching the outlined function's signature built above.
    SmallVector<Value> callValues;
    llvm::append_range(C&: callValues, R: newBlock->getArguments());
    llvm::append_range(C&: callValues, R&: outlinedValues);
    auto call = rewriter.create<func::CallOp>(location: loc, args&: outlinedFunc, args&: callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(from: originalTerminator->getOperands(), to: call->getResults());
    rewriter.clone(op&: *originalTerminator, mapper&: bvm);
    rewriter.eraseOp(op: originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(t&: outlinedValues, u: outlinedFuncBlockArgs.take_back(
                                                 N: outlinedValues.size()))) {
    Value orig = std::get<0>(t&: it);
    Value repl = std::get<1>(t&: it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // For constant-index captures, prefer a local clone over a func
      // argument so downstream canonicalizations can fold it.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        IRMapping bvm;
        repl = rewriter.clone(op&: *cst, mapper&: bvm)->getResult(idx: 0);
      }
    }
    // Replace uses of the captured value only inside the outlined function;
    // uses in the caller keep referring to the original value.
    orig.replaceUsesWithIf(newValue: repl, shouldReplace: [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(other: opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
219
220LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
221 func::FuncOp *thenFn, StringRef thenFnName,
222 func::FuncOp *elseFn, StringRef elseFnName) {
223 IRRewriter rewriter(b);
224 Location loc = ifOp.getLoc();
225 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
226 if (thenFn && !ifOp.getThenRegion().empty()) {
227 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
228 rewriter, loc, region&: ifOp.getThenRegion(), funcName: thenFnName);
229 if (failed(Result: outlinedFuncOpOrFailure))
230 return failure();
231 *thenFn = *outlinedFuncOpOrFailure;
232 }
233 if (elseFn && !ifOp.getElseRegion().empty()) {
234 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
235 rewriter, loc, region&: ifOp.getElseRegion(), funcName: elseFnName);
236 if (failed(Result: outlinedFuncOpOrFailure))
237 return failure();
238 *elseFn = *outlinedFuncOpOrFailure;
239 }
240 return success();
241}
242
243bool mlir::getInnermostParallelLoops(Operation *rootOp,
244 SmallVectorImpl<scf::ParallelOp> &result) {
245 assert(rootOp != nullptr && "Root operation must not be a nullptr.");
246 bool rootEnclosesPloops = false;
247 for (Region &region : rootOp->getRegions()) {
248 for (Block &block : region.getBlocks()) {
249 for (Operation &op : block) {
250 bool enclosesPloops = getInnermostParallelLoops(rootOp: &op, result);
251 rootEnclosesPloops |= enclosesPloops;
252 if (auto ploop = dyn_cast<scf::ParallelOp>(Val&: op)) {
253 rootEnclosesPloops = true;
254
255 // Collect parallel loop if it is an innermost one.
256 if (!enclosesPloops)
257 result.push_back(Elt: ploop);
258 }
259 }
260 }
261 }
262 return rootEnclosesPloops;
263}
264
265// Build the IR that performs ceil division of a positive value by a constant:
266// ceildiv(a, B) = divis(a + (B-1), B)
267// where divis is rounding-to-zero division.
268static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
269 int64_t divisor) {
270 assert(divisor > 0 && "expected positive divisor");
271 assert(dividend.getType().isIntOrIndex() &&
272 "expected integer or index-typed value");
273
274 Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
275 location: loc, args: builder.getIntegerAttr(type: dividend.getType(), value: divisor - 1));
276 Value divisorCst = builder.create<arith::ConstantOp>(
277 location: loc, args: builder.getIntegerAttr(type: dividend.getType(), value: divisor));
278 Value sum = builder.create<arith::AddIOp>(location: loc, args&: dividend, args&: divisorMinusOneCst);
279 return builder.create<arith::DivUIOp>(location: loc, args&: sum, args&: divisorCst);
280}
281
282// Build the IR that performs ceil division of a positive value by another
283// positive value:
284// ceildiv(a, b) = divis(a + (b - 1), b)
285// where divis is rounding-to-zero division.
286static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
287 Value divisor) {
288 assert(dividend.getType().isIntOrIndex() &&
289 "expected integer or index-typed value");
290 Value cstOne = builder.create<arith::ConstantOp>(
291 location: loc, args: builder.getOneAttr(type: dividend.getType()));
292 Value divisorMinusOne = builder.create<arith::SubIOp>(location: loc, args&: divisor, args&: cstOne);
293 Value sum = builder.create<arith::AddIOp>(location: loc, args&: dividend, args&: divisorMinusOne);
294 return builder.create<arith::DivUIOp>(location: loc, args&: sum, args&: divisor);
295}
296
/// Returns the trip count of `forOp` if its lower bound, upper bound and step
/// are constants, or std::nullopt otherwise. Trip count is computed as
/// ceilDiv(upperBound - lowerBound, step).
300static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
301 return constantTripCount(lb: forOp.getLowerBound(), ub: forOp.getUpperBound(),
302 step: forOp.getStep());
303}
304
305/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
306/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
307/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
308/// unrolled iteration using annotateFn.
309static void generateUnrolledLoop(
310 Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
311 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
312 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
313 ValueRange iterArgs, ValueRange yieldedValues) {
314 // Builder to insert unrolled bodies just before the terminator of the body of
315 // 'forOp'.
316 auto builder = OpBuilder::atBlockTerminator(block: loopBodyBlock);
317
318 constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
319 if (!annotateFn)
320 annotateFn = defaultAnnotateFn;
321
322 // Keep a pointer to the last non-terminator operation in the original block
323 // so that we know what to clone (since we are doing this in-place).
324 Block::iterator srcBlockEnd = std::prev(x: loopBodyBlock->end(), n: 2);
325
326 // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
327 SmallVector<Value, 4> lastYielded(yieldedValues);
328
329 for (unsigned i = 1; i < unrollFactor; i++) {
330 IRMapping operandMap;
331
332 // Prepare operand map.
333 operandMap.map(from&: iterArgs, to&: lastYielded);
334
335 // If the induction variable is used, create a remapping to the value for
336 // this unrolled instance.
337 if (!forOpIV.use_empty()) {
338 Value ivUnroll = ivRemapFn(i, forOpIV, builder);
339 operandMap.map(from: forOpIV, to: ivUnroll);
340 }
341
342 // Clone the original body of 'forOp'.
343 for (auto it = loopBodyBlock->begin(); it != std::next(x: srcBlockEnd); it++) {
344 Operation *clonedOp = builder.clone(op&: *it, mapper&: operandMap);
345 annotateFn(i, clonedOp, builder);
346 }
347
348 // Update yielded values.
349 for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
350 lastYielded[i] = operandMap.lookupOrDefault(from: yieldedValues[i]);
351 }
352
353 // Make sure we annotate the Ops in the original body. We do this last so that
354 // any annotations are not copied into the cloned Ops above.
355 for (auto it = loopBodyBlock->begin(); it != std::next(x: srcBlockEnd); it++)
356 annotateFn(0, &*it, builder);
357
358 // Update operands of the yield statement.
359 loopBodyBlock->getTerminator()->setOperands(lastYielded);
360}
361
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// epilogue loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(C&: forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{.mainLoopOp: forOp, .epilogueLoopOp: std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(ofr: forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(ofr: forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(ofr: forOp.getStep()).value();
    if (unrollFactor == 1) {
      // Unroll-by-1 is a no-op, but a single-iteration loop can still be
      // promoted into its parent block.
      if (*constTripCount == 1 &&
          failed(Result: forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{.mainLoopOp: forOp, .epilogueLoopOp: std::nullopt};
    }

    // Largest multiple of `unrollFactor` iterations the main loop handles;
    // the remainder (if any) goes to the epilogue loop.
    int64_t tripCountEvenMultiple =
        *constTripCount - (*constTripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
          location: loc, args: boundsBuilder.getIntegerAttr(type: forOp.getUpperBound().getType(),
                                            value: upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantOp>(
                             location: loc, args: boundsBuilder.getIntegerAttr(
                                      type: step.getType(), value: stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(location: loc, args&: upperBound, args&: lowerBound);
    Value tripCount = ceilDivPositive(builder&: boundsBuilder, loc, dividend: diff, divisor: step);
    Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
        location: loc, args: boundsBuilder.getIntegerAttr(type: tripCount.getType(), value: unrollFactor));
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(location: loc, args&: tripCount, args&: unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(location: loc, args&: tripCount, args&: tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        location: loc, args&: lowerBound,
        args: boundsBuilder.create<arith::MulIOp>(location: loc, args&: tripCountEvenMultiple, args&: step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(location: loc, args&: step, args&: unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    auto epilogueForOp = cast<scf::ForOp>(Val: epilogueBuilder.clone(op&: *forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(t&: results, u&: epilogueResults)) {
      std::get<0>(t&: e).replaceAllUsesWith(newValue: std::get<1>(t&: e));
    }
    // The epilogue resumes from the main loop's results as its init args.
    epilogueForOp->setOperands(start: epilogueForOp.getNumControlOperands(),
                               length: epilogueForOp.getInitArgs().size(), operands: results);
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      loopBodyBlock: forOp.getBody(), forOpIV: forOp.getInductionVar(), unrollFactor,
      ivRemapFn: [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            location: loc, args&: step,
            args: b.create<arith::ConstantOp>(location: loc,
                                        args: b.getIntegerAttr(type: iv.getType(), value: i)));
        return b.create<arith::AddIOp>(location: loc, args&: iv, args&: stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}
486
487/// Unrolls this loop completely.
488LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
489 IRRewriter rewriter(forOp.getContext());
490 std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
491 if (!mayBeConstantTripCount.has_value())
492 return failure();
493 uint64_t tripCount = *mayBeConstantTripCount;
494 if (tripCount == 0)
495 return success();
496 if (tripCount == 1)
497 return forOp.promoteIfSingleIteration(rewriter);
498 return loopUnrollByFactor(forOp, unrollFactor: tripCount);
499}
500
501/// Check if bounds of all inner loops are defined outside of `forOp`
502/// and return false if not.
503static bool areInnerBoundsInvariant(scf::ForOp forOp) {
504 auto walkResult = forOp.walk(callback: [&](scf::ForOp innerForOp) {
505 if (!forOp.isDefinedOutsideOfLoop(value: innerForOp.getLowerBound()) ||
506 !forOp.isDefinedOutsideOfLoop(value: innerForOp.getUpperBound()) ||
507 !forOp.isDefinedOutsideOfLoop(value: innerForOp.getStep()))
508 return WalkResult::interrupt();
509
510 return WalkResult::advance();
511 });
512 return !walkResult.wasInterrupted();
513}
514
/// Unrolls and jams this loop by the specified factor. Requires a constant
/// trip count that is a multiple of (or smaller than) the factor, inner loop
/// bounds that are invariant in `forOp`, and a loop without results.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG("failed to unroll and jam: inner bounds are not invariant");
    return failure();
  }

  // Currently, `scf.for` operations with results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG("failed to unroll and jam: unsupported loop with results");
    return failure();
  }

  // Currently, only constant trip count that divided by the unroll factor is
  // supported.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG("failed to unroll and jam: trip count could not be determined");
    return failure();
  }
  if (unrollJamFactor > *tripCount) {
    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
         "count");
    unrollJamFactor = *tripCount;
  } else if (*tripCount % unrollJamFactor != 0) {
    LDBG("failed to unroll and jam: unsupported trip count that is not a "
         "multiple of unroll jam factor");
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(C&: forOp.getBody()->getOperations()))
    return success();

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer<scf::ForOp> jbg;
  jbg.walk(op: forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk(callback: [&](scf::ForOp innerForOp) { innerLoops.push_back(Elt: innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(Val: oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(in_start: oldIterOperands.begin(), in_end: oldIterOperands.end());
      dupYieldOperands.append(in_start: oldYieldOperands.begin(), in_end: oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(Val: *oldForOp.replaceWithAdditionalYields(
            rewriter, newInitOperands: dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            newYieldValuesFn: [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(Elt: newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(from: newIterArgs[j],
                               to: newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(from: newResults[j],
                               to: newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
  rewriter.setInsertionPoint(forOp);
  int64_t step = forOp.getConstantStep()->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      location: forOp.getLoc(), args: forOp.getStep(),
      args: rewriter.createOrFold<arith::ConstantOp>(
          location: forOp.getLoc(), args: rewriter.getIndexAttr(value: unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies). Cloning
  // happens in-place, so each sub-block is cloned right after itself.
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(x: subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            location: forOp.getLoc(), args: builder.getIndexAttr(value: step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(location: forOp.getLoc(), args&: forOpIV, args&: ivTag);
        operandMaps[i - 1].map(from: forOpIV, to: ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(x: subBlock.second); ++it)
        builder.clone(op&: *it, mapper&: operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(Val: newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such mapped value exists.
        newForOp.setOperand(i: numControlOperands + i * oldNumIterOperands + j,
                            value: operandMaps[i - 1].lookupOrDefault(
                                from: newForOp.getOperand(i: numControlOperands + j)));
        yieldOp.setOperand(
            i: i * oldNumYieldOperands + j,
            value: operandMaps[i - 1].lookupOrDefault(from: yieldOp.getOperand(i: j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}
679
680Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
681 OpFoldResult lb, OpFoldResult ub,
682 OpFoldResult step) {
683 Range normalizedLoopBounds;
684 normalizedLoopBounds.offset = rewriter.getIndexAttr(value: 0);
685 normalizedLoopBounds.stride = rewriter.getIndexAttr(value: 1);
686 AffineExpr s0, s1, s2;
687 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2);
688 AffineExpr e = (s1 - s0).ceilDiv(other: s2);
689 normalizedLoopBounds.size =
690 affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: e, operands: {lb, ub, step});
691 return normalizedLoopBounds;
692}
693
694Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
695 OpFoldResult lb, OpFoldResult ub,
696 OpFoldResult step) {
697 if (getType(ofr: lb).isIndex()) {
698 return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
699 }
700 // For non-index types, generate `arith` instructions
701 // Check if the loop is already known to have a constant zero lower bound or
702 // a constant one step.
703 bool isZeroBased = false;
704 if (auto lbCst = getConstantIntValue(ofr: lb))
705 isZeroBased = lbCst.value() == 0;
706
707 bool isStepOne = false;
708 if (auto stepCst = getConstantIntValue(ofr: step))
709 isStepOne = stepCst.value() == 1;
710
711 Type rangeType = getType(ofr: lb);
712 assert(rangeType == getType(ub) && rangeType == getType(step) &&
713 "expected matching types");
714
715 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
716 // assuming the step is strictly positive. Update the bounds and the step
717 // of the loop to go from 0 to the number of iterations, if necessary.
718 if (isZeroBased && isStepOne)
719 return {.offset: lb, .size: ub, .stride: step};
720
721 OpFoldResult diff = ub;
722 if (!isZeroBased) {
723 diff = rewriter.createOrFold<arith::SubIOp>(
724 location: loc, args: getValueOrCreateConstantIntOp(b&: rewriter, loc, ofr: ub),
725 args: getValueOrCreateConstantIntOp(b&: rewriter, loc, ofr: lb));
726 }
727 OpFoldResult newUpperBound = diff;
728 if (!isStepOne) {
729 newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
730 location: loc, args: getValueOrCreateConstantIntOp(b&: rewriter, loc, ofr: diff),
731 args: getValueOrCreateConstantIntOp(b&: rewriter, loc, ofr: step));
732 }
733
734 OpFoldResult newLowerBound = rewriter.getZeroAttr(type: rangeType);
735 OpFoldResult newStep = rewriter.getOneAttr(type: rangeType);
736
737 return {.offset: newLowerBound, .size: newUpperBound, .stride: newStep};
738}
739
/// Denormalize an index-typed induction variable: materialize
/// `normalizedIv * origStep + origLb` via a composed/folded `affine.apply`
/// and redirect all uses of `normalizedIv` to the denormalized value, except
/// the use inside the apply itself.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  // Build e(d0)[s0, s1] = d0 * s1 + s0, i.e. iv * origStep + origLb.
  AffineExpr d0, s0, s1;
  bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1);
  bindDims(ctx: rewriter.getContext(), exprs&: d0);
  AffineExpr e = d0 * s1 + s0;
  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
      b&: rewriter, loc, expr: e, operands: ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedIvVal =
      getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: denormalizedIv);
  SmallPtrSet<Operation *, 1> preservedUses;
  // If an `affine.apply` operation is generated for denormalization, the use
  // of `normalizedIv` in that op must not be replaced (it would otherwise be
  // rewritten to use its own result). Such an op isn't generated when
  // `origLb == 0` and `origStep == 1`.
  if (!isZeroInteger(v: origLb) || !isOneInteger(v: origStep)) {
    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
      preservedUses.insert(Ptr: preservedUse);
    }
  }
  rewriter.replaceAllUsesExcept(from: normalizedIv, to: denormalizedIvVal, preservedUsers: preservedUses);
}
764
765void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
766 Value normalizedIv, OpFoldResult origLb,
767 OpFoldResult origStep) {
768 if (getType(ofr: origLb).isIndex()) {
769 return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
770 origLb, origStep);
771 }
772 Value denormalizedIv;
773 SmallPtrSet<Operation *, 2> preserve;
774 bool isStepOne = isOneInteger(v: origStep);
775 bool isZeroBased = isZeroInteger(v: origLb);
776
777 Value scaled = normalizedIv;
778 if (!isStepOne) {
779 Value origStepValue =
780 getValueOrCreateConstantIntOp(b&: rewriter, loc, ofr: origStep);
781 scaled = rewriter.create<arith::MulIOp>(location: loc, args&: normalizedIv, args&: origStepValue);
782 preserve.insert(Ptr: scaled.getDefiningOp());
783 }
784 denormalizedIv = scaled;
785 if (!isZeroBased) {
786 Value origLbValue = getValueOrCreateConstantIntOp(b&: rewriter, loc, ofr: origLb);
787 denormalizedIv = rewriter.create<arith::AddIOp>(location: loc, args&: scaled, args&: origLbValue);
788 preserve.insert(Ptr: denormalizedIv.getDefiningOp());
789 }
790
791 rewriter.replaceAllUsesExcept(from: normalizedIv, to: denormalizedIv, preservedUsers: preserve);
792}
793
794static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
795 ArrayRef<OpFoldResult> values) {
796 assert(!values.empty() && "unexecpted empty array");
797 AffineExpr s0, s1;
798 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1);
799 AffineExpr mul = s0 * s1;
800 OpFoldResult products = rewriter.getIndexAttr(value: 1);
801 for (auto v : values) {
802 products = affine::makeComposedFoldedAffineApply(
803 b&: rewriter, loc, expr: mul, operands: ArrayRef<OpFoldResult>{products, v});
804 }
805 return products;
806}
807
808/// Helper function to multiply a sequence of values.
809static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
810 ArrayRef<Value> values) {
811 assert(!values.empty() && "unexpected empty list");
812 if (getType(ofr: values.front()).isIndex()) {
813 SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
814 OpFoldResult product = getProductOfIndexes(rewriter, loc, values: ofrs);
815 return getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: product);
816 }
817 std::optional<Value> productOf;
818 for (auto v : values) {
819 auto vOne = getConstantIntValue(ofr: v);
820 if (vOne && vOne.value() == 1)
821 continue;
822 if (productOf)
823 productOf =
824 rewriter.create<arith::MulIOp>(location: loc, args&: productOf.value(), args&: v).getResult();
825 else
826 productOf = v;
827 }
828 if (!productOf) {
829 productOf = rewriter
830 .create<arith::ConstantOp>(
831 location: loc, args: rewriter.getOneAttr(type: getType(ofr: values.front())))
832 .getResult();
833 }
834 return productOf.value();
835}
836
/// For each original loop, the value of the
/// induction variable can be obtained by dividing the induction variable of
/// the linearized loop by the total number of iterations of the loops nested
/// in it modulo the number of iterations in this loop (remove the values
/// related to the outer loops):
///   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
/// Compute these iteratively from the innermost loop by creating a "running
/// quotient" of division by the range.
/// Returns one delinearized IV per entry of `ubs` plus the set of ops just
/// created; uses of `linearizedIv` inside those ops must be preserved when the
/// caller replaces its remaining uses.
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                             Value linearizedIv, ArrayRef<Value> ubs) {

  // Index type: the dedicated affine.delinearize_index op does all the work.
  if (linearizedIv.getType().isIndex()) {
    Operation *delinearizedOp =
        rewriter.create<affine::AffineDelinearizeIndexOp>(location: loc, args&: linearizedIv,
                                                          args&: ubs);
    auto resultVals = llvm::map_to_vector(
        C: delinearizedOp->getResults(), F: [](OpResult r) -> Value { return r; });
    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
  }

  SmallVector<Value> delinearizedIvs(ubs.size());
  SmallPtrSet<Operation *, 2> preservedUsers;

  // Record which upper bounds are statically one; those dimensions contribute
  // nothing to the linearized IV and their IVs are constant zero.
  llvm::BitVector isUbOne(ubs.size());
  for (auto [index, ub] : llvm::enumerate(First&: ubs)) {
    auto ubCst = getConstantIntValue(ofr: ub);
    if (ubCst && ubCst.value() == 1)
      isUbOne.set(index);
  }

  // Prune the lead ubs that are all ones.
  unsigned numLeadingOneUbs = 0;
  for (auto [index, ub] : llvm::enumerate(First&: ubs)) {
    if (!isUbOne.test(Idx: index)) {
      break;
    }
    delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
        location: loc, args: rewriter.getZeroAttr(type: ub.getType()));
    numLeadingOneUbs++;
  }

  // Walk from the innermost remaining dimension outwards, maintaining the
  // running quotient in `previous`: divide by each inner range, then take the
  // remainder modulo the current range (except for the outermost dimension,
  // which keeps the full quotient).
  Value previous = linearizedIv;
  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
    if (i != numLeadingOneUbs && !isUbOne.test(Idx: idx + 1)) {
      previous = rewriter.create<arith::DivSIOp>(location: loc, args&: previous, args: ubs[idx + 1]);
      preservedUsers.insert(Ptr: previous.getDefiningOp());
    }
    Value iv = previous;
    if (i != e - 1) {
      if (!isUbOne.test(Idx: idx)) {
        iv = rewriter.create<arith::RemSIOp>(location: loc, args&: previous, args: ubs[idx]);
        preservedUsers.insert(Ptr: iv.getDefiningOp());
      } else {
        // Unit-range dimension: its IV is always zero.
        iv = rewriter.create<arith::ConstantOp>(
            location: loc, args: rewriter.getZeroAttr(type: ubs[idx].getType()));
      }
    }
    delinearizedIvs[idx] = iv;
  }
  return {delinearizedIvs, preservedUsers};
}
900
/// Coalesce the perfectly nested loops in `loops` (outermost first) into a
/// single loop: each loop is normalized to [0, tripCount) step 1, the
/// outermost loop's upper bound becomes the product of all trip counts, the
/// original IVs are recovered by delinearization, and the inner loop bodies
/// are inlined into the outermost loop. Fails when fewer than two loops are
/// given.
LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
                                  MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return failure();

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops) {
    OpBuilder::InsertionGuard g(rewriter);
    // Normalization ops must dominate the whole nest, so insert before the
    // outermost loop; denormalization ops go at the top of the innermost body.
    rewriter.setInsertionPoint(outermost);
    Value lb = loop.getLowerBound();
    Value ub = loop.getUpperBound();
    Value step = loop.getStep();
    auto newLoopRange =
        emitNormalizedLoopBounds(rewriter, loc: loop.getLoc(), lb, ub, step);

    rewriter.modifyOpInPlace(root: loop, callable: [&]() {
      loop.setLowerBound(getValueOrCreateConstantIntOp(b&: rewriter, loc: loop.getLoc(),
                                                       ofr: newLoopRange.offset));
      loop.setUpperBound(getValueOrCreateConstantIntOp(b&: rewriter, loc: loop.getLoc(),
                                                       ofr: newLoopRange.size));
      loop.setStep(getValueOrCreateConstantIntOp(b&: rewriter, loc: loop.getLoc(),
                                                 ofr: newLoopRange.stride));
    });
    rewriter.setInsertionPointToStart(innermost.getBody());
    denormalizeInductionVariable(rewriter, loc: loop.getLoc(),
                                 normalizedIv: loop.getInductionVar(), origLb: lb, origStep: step);
  }

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outermost);
  Location loc = outermost.getLoc();
  SmallVector<Value> upperBounds = llvm::map_to_vector(
      C&: loops, F: [](auto loop) { return loop.getUpperBound(); });
  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, values: upperBounds);
  outermost.setUpperBound(upperBound);

  // 3. Delinearize the coalesced IV and replace uses of each original IV,
  // preserving the uses inside the delinearization ops themselves.
  rewriter.setInsertionPointToStart(innermost.getBody());
  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
      rewriter, loc, linearizedIv: outermost.getInductionVar(), ubs: upperBounds);
  rewriter.replaceAllUsesExcept(from: outermost.getInductionVar(), to: delinearizeIvs[0],
                                preservedUsers);

  // 4. Inline each inner loop body into its parent, from innermost outwards,
  // rewiring iter args and yields as the loops disappear.
  for (int i = loops.size() - 1; i > 0; --i) {
    auto outerLoop = loops[i - 1];
    auto innerLoop = loops[i];

    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
    auto yieldedVals = llvm::to_vector(Range: innerTerminator->getOperands());
    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
    for (Value &yieldedVal : yieldedVals) {
      // The yielded value may be an iteration argument of the inner loop
      // which is about to be inlined.
      auto iter = llvm::find(Range: innerLoop.getRegionIterArgs(), Val: yieldedVal);
      if (iter != innerLoop.getRegionIterArgs().end()) {
        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
        // `outerLoop` iter args identical to the `innerLoop` init args.
        assert(iterArgIndex < innerLoop.getInitArgs().size());
        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
      }
    }
    rewriter.eraseOp(op: innerTerminator);

    // The inner body's block args (IV + iter args) are remapped to the
    // delinearized IV and the outer loop's iter args.
    SmallVector<Value> innerBlockArgs;
    innerBlockArgs.push_back(Elt: delinearizeIvs[i]);
    llvm::append_range(C&: innerBlockArgs, R: outerLoop.getRegionIterArgs());
    rewriter.inlineBlockBefore(source: innerLoop.getBody(), dest: outerLoop.getBody(),
                               before: Block::iterator(innerLoop), argValues: innerBlockArgs);
    rewriter.replaceOp(op: innerLoop, newValues: yieldedVals);
  }
  return success();
}
978
979LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
980 if (loops.empty()) {
981 return failure();
982 }
983 IRRewriter rewriter(loops.front().getContext());
984 return coalesceLoops(rewriter, loops);
985}
986
/// Walk the perfectly nested loop nest rooted at `op` and coalesce every band
/// of loops whose bounds are defined above the band and whose iter-arg chains
/// are compatible. Returns success if at least one band was coalesced.
LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
  LogicalResult result(failure());
  SmallVector<scf::ForOp> loops;
  getPerfectlyNestedLoops(nestedLoops&: loops, root: op);

  // Look for a band of loops that can be coalesced, i.e. perfectly nested
  // loops with bounds defined above some loop.

  // 1. For each loop, find above which parent loop its bounds operands are
  // defined.
  SmallVector<unsigned> operandsDefinedAbove(loops.size());
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    // Default: the loop's bounds are only available at its own level.
    operandsDefinedAbove[i] = i;
    for (unsigned j = 0; j < i; ++j) {
      SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
                                           loops[i].getUpperBound(),
                                           loops[i].getStep()};
      if (areValuesDefinedAbove(values: boundsOperands, limit&: loops[j].getRegion())) {
        operandsDefinedAbove[i] = j;
        break;
      }
    }
  }

  // 2. For each inner loop check that the iter_args for the immediately outer
  // loop are the init for the immediately inner loop and that the yields of the
  // return of the inner loop is the yield for the immediately outer loop. Keep
  // track of where the chain starts from for each loop.
  SmallVector<unsigned> iterArgChainStart(loops.size());
  iterArgChainStart[0] = 0;
  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
    // By default set the start of the chain to itself.
    iterArgChainStart[i] = i;
    auto outerloop = loops[i - 1];
    auto innerLoop = loops[i];
    if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
      continue;
    }
    if (!llvm::equal(LRange: outerloop.getRegionIterArgs(), RRange: innerLoop.getInitArgs())) {
      continue;
    }
    auto outerloopTerminator = outerloop.getBody()->getTerminator();
    if (!llvm::equal(LRange: outerloopTerminator->getOperands(),
                     RRange: innerLoop.getResults())) {
      continue;
    }
    // The chain continues: this loop inherits the outer loop's chain start.
    iterArgChainStart[i] = iterArgChainStart[i - 1];
  }

  // 3. Identify bands of loops such that the operands of all of them are
  // defined above the first loop in the band. Traverse the nest bottom-up
  // so that modifications don't invalidate the inner loops.
  for (unsigned end = loops.size(); end > 0; --end) {
    unsigned start = 0;
    for (; start < end - 1; ++start) {
      auto maxPos =
          *std::max_element(first: std::next(x: operandsDefinedAbove.begin(), n: start),
                            last: std::next(x: operandsDefinedAbove.begin(), n: end));
      // All bounds in [start, end) must be defined at or above `start`.
      if (maxPos > start)
        continue;
      // The iter-arg chain of the band must also reach back to `start`.
      if (iterArgChainStart[end - 1] > start)
        continue;
      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
      if (succeeded(Result: coalesceLoops(loops: band)))
        result = success();
      break;
    }
    // If a band was found and transformed, keep looking at the loops above
    // the outermost transformed loop.
    if (start != end - 1)
      end = start + 1;
  }
  return result;
}
1061
/// Collapse the dimensions of `loops` (an scf.parallel) according to
/// `combinedDimensions`: each entry lists the original dimensions merged into
/// one dimension of a new scf.parallel, whose range is the product of the
/// normalized ranges of the merged dimensions. The original IVs are recovered
/// inside the new loop with div/rem arithmetic, and the old loop is erased.
void mlir::collapseParallelLoops(
    RewriterBase &rewriter, scf::ParallelOp loops,
    ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(Range&: combinedDimensions);
  for (auto &dims : sortedDimensions)
    llvm::sort(C&: dims);

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder::InsertionGuard g2(rewriter);
    // Bound computations go before the loop; denormalization of the IV goes
    // at the top of the loop body.
    rewriter.setInsertionPoint(loops);
    Value lb = loops.getLowerBound()[i];
    Value ub = loops.getUpperBound()[i];
    Value step = loops.getStep()[i];
    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUpperBounds.push_back(Elt: getValueOrCreateConstantIntOp(
        b&: rewriter, loc: loops.getLoc(), ofr: newLoopRange.size));

    rewriter.setInsertionPointToStart(loops.getBody());
    denormalizeInductionVariable(rewriter, loc, normalizedIv: loops.getInductionVars()[i], origLb: lb,
                                 origStep: step);
  }

  // Combine iteration spaces: each collapsed dimension iterates over the
  // product of the normalized ranges it merges, from 0 with step 1.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
  auto cst1 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
  for (auto &sortedDimension : sortedDimensions) {
    Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
    for (auto idx : sortedDimension) {
      newUpperBound = rewriter.create<arith::MulIOp>(
          location: loc, args&: newUpperBound, args&: normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(Elt: cst0);
    steps.push_back(Elt: cst1);
    upperBounds.push_back(Elt: newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = rewriter.create<scf::ParallelOp>(
      location: loc, args&: lowerBounds, args&: upperBounds, args&: steps,
      args: [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = insideBuilder.create<arith::RemSIOp>(
                location: loc, args&: previous, args&: normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(orig: loops.getBody()->getArgument(i: idx), replacement: iv,
                                       region&: loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                location: loc, args&: previous, args&: normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(orig: loops.getBody()->getArgument(i: idx),
                                     replacement: previous, region&: loops.getRegion());
        }
      });

  // Replace the old loop with the new loop: drop the old terminator, splice
  // the old body before the new loop's terminator, and erase the old op.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      where: Block::iterator(newPloop.getBody()->back()),
      L2&: loops.getBody()->getOperations());
  loops.erase();
}
1148
1149// Hoist the ops within `outer` that appear before `inner`.
1150// Such ops include the ops that have been introduced by parametric tiling.
1151// Ops that come from triangular loops (i.e. that belong to the program slice
1152// rooted at `outer`) and ops that have side effects cannot be hoisted.
1153// Return failure when any op fails to hoist.
1154static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1155 SetVector<Operation *> forwardSlice;
1156 ForwardSliceOptions options;
1157 options.filter = [&inner](Operation *op) {
1158 return op != inner.getOperation();
1159 };
1160 getForwardSlice(root: outer.getInductionVar(), forwardSlice: &forwardSlice, options);
1161 LogicalResult status = success();
1162 SmallVector<Operation *, 8> toHoist;
1163 for (auto &op : outer.getBody()->without_terminator()) {
1164 // Stop when encountering the inner loop.
1165 if (&op == inner.getOperation())
1166 break;
1167 // Skip over non-hoistable ops.
1168 if (forwardSlice.count(key: &op) > 0) {
1169 status = failure();
1170 continue;
1171 }
1172 // Skip intermediate scf::ForOp, these are not considered a failure.
1173 if (isa<scf::ForOp>(Val: op))
1174 continue;
1175 // Skip other ops with regions.
1176 if (op.getNumRegions() > 0) {
1177 status = failure();
1178 continue;
1179 }
1180 // Skip if op has side effects.
1181 // TODO: loads to immutable memory regions are ok.
1182 if (!isMemoryEffectFree(op: &op)) {
1183 status = failure();
1184 continue;
1185 }
1186 toHoist.push_back(Elt: &op);
1187 }
1188 auto *outerForOp = outer.getOperation();
1189 for (auto *op : toHoist)
1190 op->moveBefore(existingOp: outerForOp);
1191 return status;
1192}
1193
1194// Traverse the interTile and intraTile loops and try to hoist ops such that
1195// bands of perfectly nested loops are isolated.
1196// Return failure if either perfect interTile or perfect intraTile bands cannot
1197// be formed.
1198static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1199 LogicalResult status = success();
1200 const Loops &interTile = tileLoops.first;
1201 const Loops &intraTile = tileLoops.second;
1202 auto size = interTile.size();
1203 assert(size == intraTile.size());
1204 if (size <= 1)
1205 return success();
1206 for (unsigned s = 1; s < size; ++s)
1207 status = succeeded(Result: status) ? hoistOpsBetween(outer: intraTile[0], inner: intraTile[s])
1208 : failure();
1209 for (unsigned s = 1; s < size; ++s)
1210 status = succeeded(Result: status) ? hoistOpsBetween(outer: interTile[0], inner: interTile[s])
1211 : failure();
1212 return status;
1213}
1214
/// Collect perfectly nested loops starting from `rootForOps`. Loops are
/// perfectly nested if each loop is the first and only non-terminator operation
/// in the parent loop. Collect at most `maxLoops` loops and append them to
/// `forOps`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  for (unsigned i = 0; i < maxLoops; ++i) {
    forOps.push_back(rootForOp);
    Block &body = rootForOp.getRegion().front();
    // Stop if the body holds more than one op plus the terminator
    // (begin == end - 2 means exactly one non-terminator op).
    if (body.begin() != std::prev(x: body.end(), n: 2))
      return;

    // Descend only if that single op is itself a loop of the same kind.
    rootForOp = dyn_cast<T>(&body.front());
    if (!rootForOp)
      return;
  }
}
1234
/// Strip-mine `forOp` by `factor`: multiply its step by `factor` and, inside
/// each loop in `targets`, create a new inner loop running from the outer IV
/// to min(outer ub, iv + new step) with the original step, moving the target
/// body into it. Returns the newly created inner loops (one per target).
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  // The outer loop now advances by `originalStep * factor` per iteration.
  OpBuilder b(forOp);
  forOp.setStep(b.create<arith::MulIOp>(location: forOp.getLoc(), args&: originalStep, args&: factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator(block: (t.getBody()));
    // Clamp the inner upper bound so the last (partial) tile does not
    // overrun the original loop bound.
    Value stepped = b.create<arith::AddIOp>(location: t.getLoc(), args&: iv, args: forOp.getStep());
    Value ub =
        b.create<arith::MinSIOp>(location: t.getLoc(), args: forOp.getUpperBound(), args&: stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<scf::ForOp>(location: t.getLoc(), args&: iv, args&: ub, args&: originalStep);
    newForOp.getBody()->getOperations().splice(
        where: newForOp.getBody()->getOperations().begin(),
        L2&: t.getBody()->getOperations(), first: begin, last: std::next(x: begin, n: nOps - 1));
    replaceAllUsesInRegionWith(orig: iv, replacement: newForOp.getInductionVar(),
                               region&: newForOp.getRegion());

    innerLoops.push_back(Elt: newForOp);
  }

  return innerLoops;
}
1268
1269// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1270// Returns the new for operation, nested immediately under `target`.
1271template <typename SizeType>
1272static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1273 scf::ForOp target) {
1274 // TODO: Use cheap structural assertions that targets are nested under
1275 // forOp and that targets are not nested under each other when DominanceInfo
1276 // exposes the capability. It seems overkill to construct a whole function
1277 // dominance tree at this point.
1278 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1279 assert(res.size() == 1 && "Expected 1 inner forOp");
1280 return res[0];
1281}
1282
1283SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
1284 ArrayRef<Value> sizes,
1285 ArrayRef<scf::ForOp> targets) {
1286 SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
1287 SmallVector<scf::ForOp, 8> currentTargets(targets);
1288 for (auto it : llvm::zip(t&: forOps, u&: sizes)) {
1289 auto step = stripmineSink(forOp: std::get<0>(t&: it), factor: std::get<1>(t&: it), targets: currentTargets);
1290 res.push_back(Elt: step);
1291 currentTargets = step;
1292 }
1293 return res;
1294}
1295
1296Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
1297 scf::ForOp target) {
1298 SmallVector<scf::ForOp, 8> res;
1299 for (auto loops : tile(forOps, sizes, targets: ArrayRef<scf::ForOp>(target)))
1300 res.push_back(Elt: llvm::getSingleElement(C&: loops));
1301 return res;
1302}
1303
1304Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1305 // Collect perfectly nested loops. If more size values provided than nested
1306 // loops available, truncate `sizes`.
1307 SmallVector<scf::ForOp, 4> forOps;
1308 forOps.reserve(N: sizes.size());
1309 getPerfectlyNestedLoopsImpl(forOps, rootForOp, maxLoops: sizes.size());
1310 if (forOps.size() < sizes.size())
1311 sizes = sizes.take_front(N: forOps.size());
1312
1313 return ::tile(forOps, sizes, target: forOps.back());
1314}
1315
/// Collect the perfectly nested loops rooted at `root` into `nestedLoops`,
/// with no bound on the nest depth.
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(forOps&: nestedLoops, rootForOp: root);
}
1320
/// Tile the perfectly nested loops rooted at `rootForOp` so that the i-th
/// outer loop executes `sizes[i]` iterations, returning the (outer, inner)
/// loop bands as a TileLoops pair. Sizes beyond the nest depth are dropped.
TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops. If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(N: sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, maxLoops: sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(N: forOps.size());

  // Compute the tile sizes such that i-th outer loop executes size[i]
  // iterations. Given that the loop current executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(N: sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    // numIterations = ceildiv(ub - lb, step); tile = ceildiv(numIterations,
    // sizes[i]).
    Value diff = builder.create<arith::SubIOp>(location: loc, args: forOp.getUpperBound(),
                                               args: forOp.getLowerBound());
    Value numIterations = ceilDivPositive(builder, loc, dividend: diff, divisor: forOp.getStep());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, dividend: numIterations, divisor: sizes[i]);
    tileSizes.push_back(Elt: iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, sizes: tileSizes, target: forOps.back());
  TileLoops tileLoops = std::make_pair(x&: forOps, y&: intraTile);

  // TODO: for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  (void)tryIsolateBands(tileLoops);

  return tileLoops;
}
1362
/// Fuse two independent sibling scf.forall loops into one: a new forall takes
/// `source`'s bounds/mapping, concatenates the shared_outs of `target` then
/// `source`, clones both bodies (target first) and both in_parallel
/// terminators into it, and replaces the results of the two original loops.
/// Assumes the caller has verified the loops are fusible (matching iteration
/// spaces, no dependences between them).
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused shared_outs.
  SmallVector<Value> fusedOuts;
  llvm::append_range(C&: fusedOuts, R: target.getOutputs());
  llvm::append_range(C&: fusedOuts, R: source.getOutputs());

  // Create a new scf.forall op after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
      location: source.getLoc(), args: source.getMixedLowerBound(), args: source.getMixedUpperBound(),
      args: source.getMixedStep(), args&: fusedOuts, args: source.getMapping());

  // Map control operands: both original IV sets map to the fused IVs.
  IRMapping mapping;
  mapping.map(from: target.getInductionVars(), to: fusedLoop.getInductionVars());
  mapping.map(from: source.getInductionVars(), to: fusedLoop.getInductionVars());

  // Map shared outs: target's iter args occupy the leading positions,
  // source's the trailing ones, mirroring the fusedOuts concatenation above.
  mapping.map(from: target.getRegionIterArgs(),
              to: fusedLoop.getRegionIterArgs().take_front(N: numTargetOuts));
  mapping.map(from: source.getRegionIterArgs(),
              to: fusedLoop.getRegionIterArgs().take_back(N: numSourceOuts));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapper&: mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapper&: mapping);

  // Fuse the old terminator in_parallel ops into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  for (Operation &op : targetTerm.getYieldingOps())
    rewriter.clone(op, mapper&: mapping);
  for (Operation &op : sourceTerm.getYieldingOps())
    rewriter.clone(op, mapper&: mapping);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(op: target, newValues: fusedLoop.getResults().take_front(n: numTargetOuts));
  rewriter.replaceOp(op: source, newValues: fusedLoop.getResults().take_back(n: numSourceOuts));

  return fusedLoop;
}
1414
/// Fuse two independent sibling scf.for loops into one: a new for loop takes
/// `source`'s bounds, concatenates the init_args of `target` then `source`,
/// clones both bodies (target first), and yields the mapped operands of both
/// original terminators. The results of the two original loops are replaced
/// by the corresponding slices of the fused results. Assumes the caller has
/// verified the loops are fusible.
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                scf::ForOp source,
                                                RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused init_args, with target's init_args before source's init_args.
  SmallVector<Value> fusedInitArgs;
  llvm::append_range(C&: fusedInitArgs, R: target.getInitArgs());
  llvm::append_range(C&: fusedInitArgs, R: source.getInitArgs());

  // Create a new scf.for op after the source loop (with scf.yield terminator
  // (without arguments) only in case its init_args is empty).
  rewriter.setInsertionPointAfter(source);
  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
      location: source.getLoc(), args: source.getLowerBound(), args: source.getUpperBound(),
      args: source.getStep(), args&: fusedInitArgs);

  // Map original induction variables and operands to those of the fused loop.
  // Target's iter args occupy the leading positions, source's the trailing
  // ones, mirroring the fusedInitArgs concatenation above.
  IRMapping mapping;
  mapping.map(from: target.getInductionVar(), to: fusedLoop.getInductionVar());
  mapping.map(from: target.getRegionIterArgs(),
              to: fusedLoop.getRegionIterArgs().take_front(N: numTargetOuts));
  mapping.map(from: source.getInductionVar(), to: fusedLoop.getInductionVar());
  mapping.map(from: source.getRegionIterArgs(),
              to: fusedLoop.getRegionIterArgs().take_back(N: numSourceOuts));

  // Merge target's body into the new (fused) for loop and then source's body.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapper&: mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapper&: mapping);

  // Build fused yield results by appropriately mapping original yield operands.
  SmallVector<Value> yieldResults;
  for (Value operand : target.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(Elt: mapping.lookupOrDefault(from: operand));
  for (Value operand : source.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(Elt: mapping.lookupOrDefault(from: operand));
  if (!yieldResults.empty())
    rewriter.create<scf::YieldOp>(location: source.getLoc(), args&: yieldResults);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(op: target, newValues: fusedLoop.getResults().take_front(n: numTargetOuts));
  rewriter.replaceOp(op: source, newValues: fusedLoop.getResults().take_back(n: numSourceOuts));

  return fusedLoop;
}
1464
1465FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1466 scf::ForallOp forallOp) {
1467 SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1468 SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1469 SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1470
1471 if (forallOp.isNormalized())
1472 return forallOp;
1473
1474 OpBuilder::InsertionGuard g(rewriter);
1475 auto loc = forallOp.getLoc();
1476 rewriter.setInsertionPoint(forallOp);
1477 SmallVector<OpFoldResult> newUbs;
1478 for (auto [lb, ub, step] : llvm::zip_equal(t&: lbs, u&: ubs, args&: steps)) {
1479 Range normalizedLoopParams =
1480 emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1481 newUbs.push_back(Elt: normalizedLoopParams.size);
1482 }
1483 (void)foldDynamicIndexList(ofrs&: newUbs);
1484
1485 // Use the normalized builder since the lower bounds are always 0 and the
1486 // steps are always 1.
1487 auto normalizedForallOp = rewriter.create<scf::ForallOp>(
1488 location: loc, args&: newUbs, args: forallOp.getOutputs(), args: forallOp.getMapping(),
1489 args: [](OpBuilder &, Location, ValueRange) {});
1490
1491 rewriter.inlineRegionBefore(region&: forallOp.getBodyRegion(),
1492 parent&: normalizedForallOp.getBodyRegion(),
1493 before: normalizedForallOp.getBodyRegion().begin());
1494 // Remove the original empty block in the new loop.
1495 rewriter.eraseBlock(block: &normalizedForallOp.getBodyRegion().back());
1496
1497 rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1498 // Update the users of the original loop variables.
1499 for (auto [idx, iv] :
1500 llvm::enumerate(First: normalizedForallOp.getInductionVars())) {
1501 auto origLb = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: lbs[idx]);
1502 auto origStep = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: steps[idx]);
1503 denormalizeInductionVariable(rewriter, loc, normalizedIv: iv, origLb, origStep);
1504 }
1505
1506 rewriter.replaceOp(op: forallOp, newOp: normalizedForallOp);
1507 return normalizedForallOp;
1508}
1509

// source code of mlir/lib/Dialect/SCF/Utils/Utils.cpp