1//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous loop transformation routines.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Utils/Utils.h"
14#include "mlir/Analysis/SliceAnalysis.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/SideEffectInterfaces.h"
23#include "mlir/Support/MathExtras.h"
24#include "mlir/Transforms/RegionUtils.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/SmallPtrSet.h"
28#include "llvm/ADT/SmallVector.h"
29
30using namespace mlir;
31
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
struct LoopParams {
  Value lowerBound; // Loop lower bound.
  Value upperBound; // Loop upper bound (exclusive, per scf.for semantics).
  Value step;       // Loop step.
};
} // namespace
41
/// Replace `loopNest` — a perfectly nested sequence of scf.for ops, listed
/// outermost first — by an equivalent nest that additionally threads
/// `newIterOperands` through every loop level as iter_args. The values
/// yielded by the innermost loop are produced by `newYieldValuesFn`.
/// Returns the new loop nest, outermost loop first.
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, the following
  // pattern needs to be created, i.e. starting with
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  //  %0 = scf.for .. iter_args(%arg0 = %init) {
  //    %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //      %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //        %value = ...
  //        scf.yield %value
  //      }
  //      scf.yield %2
  //    }
  //    scf.yield %1
  //  }
  // ```
  //
  // The inner most loop is handled using the `replaceWithAdditionalYields`
  // that works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    // Recurse on the remaining (inner) loops; the current loop's new
    // region iter_args feed the inner nest as its new iter operands.
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
                                               innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    // The trailing results of the (new) immediately-nested loop are the
    // values this loop must yield.
    return llvm::to_vector(llvm::map_range(
        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
        [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
  return newLoopNest;
}
111
112/// Outline a region with a single block into a new FuncOp.
113/// Assumes the FuncOp result types is the type of the yielded operands of the
114/// single block. This constraint makes it easy to determine the result.
115/// This method also clones the `arith::ConstantIndexOp` at the start of
116/// `outlinedFuncBody` to alloc simple canonicalizations. If `callOp` is
117/// provided, it will be set to point to the operation that calls the outlined
118/// function.
119// TODO: support more than single-block regions.
120// TODO: more flexible constant handling.
121FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
122 Location loc,
123 Region &region,
124 StringRef funcName,
125 func::CallOp *callOp) {
126 assert(!funcName.empty() && "funcName cannot be empty");
127 if (!region.hasOneBlock())
128 return failure();
129
130 Block *originalBlock = &region.front();
131 Operation *originalTerminator = originalBlock->getTerminator();
132
133 // Outline before current function.
134 OpBuilder::InsertionGuard g(rewriter);
135 rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());
136
137 SetVector<Value> captures;
138 getUsedValuesDefinedAbove(regions: region, values&: captures);
139
140 ValueRange outlinedValues(captures.getArrayRef());
141 SmallVector<Type> outlinedFuncArgTypes;
142 SmallVector<Location> outlinedFuncArgLocs;
143 // Region's arguments are exactly the first block's arguments as per
144 // Region::getArguments().
145 // Func's arguments are cat(regions's arguments, captures arguments).
146 for (BlockArgument arg : region.getArguments()) {
147 outlinedFuncArgTypes.push_back(Elt: arg.getType());
148 outlinedFuncArgLocs.push_back(Elt: arg.getLoc());
149 }
150 for (Value value : outlinedValues) {
151 outlinedFuncArgTypes.push_back(Elt: value.getType());
152 outlinedFuncArgLocs.push_back(Elt: value.getLoc());
153 }
154 FunctionType outlinedFuncType =
155 FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
156 originalTerminator->getOperandTypes());
157 auto outlinedFunc =
158 rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
159 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
160
161 // Merge blocks while replacing the original block operands.
162 // Warning: `mergeBlocks` erases the original block, reconstruct it later.
163 int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
164 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
165 {
166 OpBuilder::InsertionGuard g(rewriter);
167 rewriter.setInsertionPointToEnd(outlinedFuncBody);
168 rewriter.mergeBlocks(
169 source: originalBlock, dest: outlinedFuncBody,
170 argValues: outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
171 // Explicitly set up a new ReturnOp terminator.
172 rewriter.setInsertionPointToEnd(outlinedFuncBody);
173 rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
174 originalTerminator->getOperands());
175 }
176
177 // Reconstruct the block that was deleted and add a
178 // terminator(call_results).
179 Block *newBlock = rewriter.createBlock(
180 parent: &region, insertPt: region.begin(),
181 argTypes: TypeRange{outlinedFuncArgTypes}.take_front(n: numOriginalBlockArguments),
182 locs: ArrayRef<Location>(outlinedFuncArgLocs)
183 .take_front(N: numOriginalBlockArguments));
184 {
185 OpBuilder::InsertionGuard g(rewriter);
186 rewriter.setInsertionPointToEnd(newBlock);
187 SmallVector<Value> callValues;
188 llvm::append_range(C&: callValues, R: newBlock->getArguments());
189 llvm::append_range(C&: callValues, R&: outlinedValues);
190 auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
191 if (callOp)
192 *callOp = call;
193
194 // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
195 // Clone `originalTerminator` to take the callOp results then erase it from
196 // `outlinedFuncBody`.
197 IRMapping bvm;
198 bvm.map(originalTerminator->getOperands(), call->getResults());
199 rewriter.clone(op&: *originalTerminator, mapper&: bvm);
200 rewriter.eraseOp(op: originalTerminator);
201 }
202
203 // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
204 // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
205 for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
206 outlinedValues.size()))) {
207 Value orig = std::get<0>(it);
208 Value repl = std::get<1>(it);
209 {
210 OpBuilder::InsertionGuard g(rewriter);
211 rewriter.setInsertionPointToStart(outlinedFuncBody);
212 if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
213 IRMapping bvm;
214 repl = rewriter.clone(*cst, bvm)->getResult(0);
215 }
216 }
217 orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
218 return outlinedFunc->isProperAncestor(opOperand.getOwner());
219 });
220 }
221
222 return outlinedFunc;
223}
224
225LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
226 func::FuncOp *thenFn, StringRef thenFnName,
227 func::FuncOp *elseFn, StringRef elseFnName) {
228 IRRewriter rewriter(b);
229 Location loc = ifOp.getLoc();
230 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
231 if (thenFn && !ifOp.getThenRegion().empty()) {
232 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
233 rewriter, loc, ifOp.getThenRegion(), thenFnName);
234 if (failed(result: outlinedFuncOpOrFailure))
235 return failure();
236 *thenFn = *outlinedFuncOpOrFailure;
237 }
238 if (elseFn && !ifOp.getElseRegion().empty()) {
239 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
240 rewriter, loc, ifOp.getElseRegion(), elseFnName);
241 if (failed(result: outlinedFuncOpOrFailure))
242 return failure();
243 *elseFn = *outlinedFuncOpOrFailure;
244 }
245 return success();
246}
247
248bool mlir::getInnermostParallelLoops(Operation *rootOp,
249 SmallVectorImpl<scf::ParallelOp> &result) {
250 assert(rootOp != nullptr && "Root operation must not be a nullptr.");
251 bool rootEnclosesPloops = false;
252 for (Region &region : rootOp->getRegions()) {
253 for (Block &block : region.getBlocks()) {
254 for (Operation &op : block) {
255 bool enclosesPloops = getInnermostParallelLoops(&op, result);
256 rootEnclosesPloops |= enclosesPloops;
257 if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
258 rootEnclosesPloops = true;
259
260 // Collect parallel loop if it is an innermost one.
261 if (!enclosesPloops)
262 result.push_back(ploop);
263 }
264 }
265 }
266 }
267 return rootEnclosesPloops;
268}
269
270// Build the IR that performs ceil division of a positive value by a constant:
271// ceildiv(a, B) = divis(a + (B-1), B)
272// where divis is rounding-to-zero division.
273static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
274 int64_t divisor) {
275 assert(divisor > 0 && "expected positive divisor");
276 assert(dividend.getType().isIndex() && "expected index-typed value");
277
278 Value divisorMinusOneCst =
279 builder.create<arith::ConstantIndexOp>(location: loc, args: divisor - 1);
280 Value divisorCst = builder.create<arith::ConstantIndexOp>(location: loc, args&: divisor);
281 Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
282 return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
283}
284
285// Build the IR that performs ceil division of a positive value by another
286// positive value:
287// ceildiv(a, b) = divis(a + (b - 1), b)
288// where divis is rounding-to-zero division.
289static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
290 Value divisor) {
291 assert(dividend.getType().isIndex() && "expected index-typed value");
292
293 Value cstOne = builder.create<arith::ConstantIndexOp>(location: loc, args: 1);
294 Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
295 Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
296 return builder.create<arith::DivUIOp>(loc, sum, divisor);
297}
298
299/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
300/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
301/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
302/// unrolled iteration using annotateFn.
303static void generateUnrolledLoop(
304 Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
305 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
306 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
307 ValueRange iterArgs, ValueRange yieldedValues) {
308 // Builder to insert unrolled bodies just before the terminator of the body of
309 // 'forOp'.
310 auto builder = OpBuilder::atBlockTerminator(block: loopBodyBlock);
311
312 if (!annotateFn)
313 annotateFn = [](unsigned, Operation *, OpBuilder) {};
314
315 // Keep a pointer to the last non-terminator operation in the original block
316 // so that we know what to clone (since we are doing this in-place).
317 Block::iterator srcBlockEnd = std::prev(x: loopBodyBlock->end(), n: 2);
318
319 // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
320 SmallVector<Value, 4> lastYielded(yieldedValues);
321
322 for (unsigned i = 1; i < unrollFactor; i++) {
323 IRMapping operandMap;
324
325 // Prepare operand map.
326 operandMap.map(from&: iterArgs, to&: lastYielded);
327
328 // If the induction variable is used, create a remapping to the value for
329 // this unrolled instance.
330 if (!forOpIV.use_empty()) {
331 Value ivUnroll = ivRemapFn(i, forOpIV, builder);
332 operandMap.map(from: forOpIV, to: ivUnroll);
333 }
334
335 // Clone the original body of 'forOp'.
336 for (auto it = loopBodyBlock->begin(); it != std::next(x: srcBlockEnd); it++) {
337 Operation *clonedOp = builder.clone(op&: *it, mapper&: operandMap);
338 annotateFn(i, clonedOp, builder);
339 }
340
341 // Update yielded values.
342 for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
343 lastYielded[i] = operandMap.lookup(from: yieldedValues[i]);
344 }
345
346 // Make sure we annotate the Ops in the original body. We do this last so that
347 // any annotations are not copied into the cloned Ops above.
348 for (auto it = loopBodyBlock->begin(); it != std::next(x: srcBlockEnd); it++)
349 annotateFn(0, &*it, builder);
350
351 // Update operands of the yield statement.
352 loopBodyBlock->getTerminator()->setOperands(lastYielded);
353}
354
355/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
356LogicalResult mlir::loopUnrollByFactor(
357 scf::ForOp forOp, uint64_t unrollFactor,
358 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
359 assert(unrollFactor > 0 && "expected positive unroll factor");
360
361 // Return if the loop body is empty.
362 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
363 return success();
364
365 // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
366 // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
367 OpBuilder boundsBuilder(forOp);
368 IRRewriter rewriter(forOp.getContext());
369 auto loc = forOp.getLoc();
370 Value step = forOp.getStep();
371 Value upperBoundUnrolled;
372 Value stepUnrolled;
373 bool generateEpilogueLoop = true;
374
375 std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
376 std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
377 std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
378 if (lbCstOp && ubCstOp && stepCstOp) {
379 // Constant loop bounds computation.
380 int64_t lbCst = lbCstOp.value();
381 int64_t ubCst = ubCstOp.value();
382 int64_t stepCst = stepCstOp.value();
383 assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
384 "expected positive loop bounds and step");
385 int64_t tripCount = mlir::ceilDiv(lhs: ubCst - lbCst, rhs: stepCst);
386
387 if (unrollFactor == 1) {
388 if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
389 return failure();
390 return success();
391 }
392
393 int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
394 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
395 int64_t stepUnrolledCst = stepCst * unrollFactor;
396
397 // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
398 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
399 if (generateEpilogueLoop)
400 upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
401 loc, upperBoundUnrolledCst);
402 else
403 upperBoundUnrolled = forOp.getUpperBound();
404
405 // Create constant for 'stepUnrolled'.
406 stepUnrolled = stepCst == stepUnrolledCst
407 ? step
408 : boundsBuilder.create<arith::ConstantIndexOp>(
409 loc, stepUnrolledCst);
410 } else {
411 // Dynamic loop bounds computation.
412 // TODO: Add dynamic asserts for negative lb/ub/step, or
413 // consider using ceilDiv from AffineApplyExpander.
414 auto lowerBound = forOp.getLowerBound();
415 auto upperBound = forOp.getUpperBound();
416 Value diff =
417 boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
418 Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
419 Value unrollFactorCst =
420 boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
421 Value tripCountRem =
422 boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
423 // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
424 Value tripCountEvenMultiple =
425 boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
426 // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
427 upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
428 loc, lowerBound,
429 boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
430 // Scale 'step' by 'unrollFactor'.
431 stepUnrolled =
432 boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
433 }
434
435 // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
436 if (generateEpilogueLoop) {
437 OpBuilder epilogueBuilder(forOp->getContext());
438 epilogueBuilder.setInsertionPoint(forOp->getBlock(),
439 std::next(x: Block::iterator(forOp)));
440 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
441 epilogueForOp.setLowerBound(upperBoundUnrolled);
442
443 // Update uses of loop results.
444 auto results = forOp.getResults();
445 auto epilogueResults = epilogueForOp.getResults();
446
447 for (auto e : llvm::zip(results, epilogueResults)) {
448 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
449 }
450 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
451 epilogueForOp.getInitArgs().size(), results);
452 (void)epilogueForOp.promoteIfSingleIteration(rewriter);
453 }
454
455 // Create unrolled loop.
456 forOp.setUpperBound(upperBoundUnrolled);
457 forOp.setStep(stepUnrolled);
458
459 auto iterArgs = ValueRange(forOp.getRegionIterArgs());
460 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
461
462 generateUnrolledLoop(
463 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
464 [&](unsigned i, Value iv, OpBuilder b) {
465 // iv' = iv + step * i;
466 auto stride = b.create<arith::MulIOp>(
467 loc, step, b.create<arith::ConstantIndexOp>(loc, i));
468 return b.create<arith::AddIOp>(loc, iv, stride);
469 },
470 annotateFn, iterArgs, yieldedValues);
471 // Promote the loop body up if this has turned into a single iteration loop.
472 (void)forOp.promoteIfSingleIteration(rewriter);
473 return success();
474}
475
476/// Transform a loop with a strictly positive step
477/// for %i = %lb to %ub step %s
478/// into a 0-based loop with step 1
479/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480/// %i = %ii * %s + %lb
481/// Insert the induction variable remapping in the body of `inner`, which is
482/// expected to be either `loop` or another loop perfectly nested under `loop`.
483/// Insert the definition of new bounds immediate before `outer`, which is
484/// expected to be either `loop` or its parent in the loop nest.
485static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
486 Value lb, Value ub, Value step) {
487 // For non-index types, generate `arith` instructions
488 // Check if the loop is already known to have a constant zero lower bound or
489 // a constant one step.
490 bool isZeroBased = false;
491 if (auto lbCst = getConstantIntValue(ofr: lb))
492 isZeroBased = lbCst.value() == 0;
493
494 bool isStepOne = false;
495 if (auto stepCst = getConstantIntValue(ofr: step))
496 isStepOne = stepCst.value() == 1;
497
498 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
499 // assuming the step is strictly positive. Update the bounds and the step
500 // of the loop to go from 0 to the number of iterations, if necessary.
501 if (isZeroBased && isStepOne)
502 return {.lowerBound: lb, .upperBound: ub, .step: step};
503
504 Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
505 Value newUpperBound =
506 isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
507
508 Value newLowerBound = isZeroBased
509 ? lb
510 : rewriter.create<arith::ConstantOp>(
511 loc, rewriter.getZeroAttr(lb.getType()));
512 Value newStep = isStepOne
513 ? step
514 : rewriter.create<arith::ConstantOp>(
515 loc, rewriter.getIntegerAttr(step.getType(), 1));
516
517 return {.lowerBound: newLowerBound, .upperBound: newUpperBound, .step: newStep};
518}
519
520/// Get back the original induction variable values after loop normalization
521static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
522 Value normalizedIv, Value origLb,
523 Value origStep) {
524 Value denormalizedIv;
525 SmallPtrSet<Operation *, 2> preserve;
526 bool isStepOne = isConstantIntValue(ofr: origStep, value: 1);
527 bool isZeroBased = isConstantIntValue(ofr: origLb, value: 0);
528
529 Value scaled = normalizedIv;
530 if (!isStepOne) {
531 scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
532 preserve.insert(Ptr: scaled.getDefiningOp());
533 }
534 denormalizedIv = scaled;
535 if (!isZeroBased) {
536 denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
537 preserve.insert(Ptr: denormalizedIv.getDefiningOp());
538 }
539
540 rewriter.replaceAllUsesExcept(from: normalizedIv, to: denormalizedIv, preservedUsers: preserve);
541}
542
543/// Helper function to multiply a sequence of values.
544static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
545 ArrayRef<Value> values) {
546 assert(!values.empty() && "unexpected empty list");
547 Value productOf = values.front();
548 for (auto v : values.drop_front()) {
549 productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
550 }
551 return productOf;
552}
553
554/// For each original loop, the value of the
555/// induction variable can be obtained by dividing the induction variable of
556/// the linearized loop by the total number of iterations of the loops nested
557/// in it modulo the number of iterations in this loop (remove the values
558/// related to the outer loops):
559/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
560/// Compute these iteratively from the innermost loop by creating a "running
561/// quotient" of division by the range.
562static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
563delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
564 Value linearizedIv, ArrayRef<Value> ubs) {
565 Value previous = linearizedIv;
566 SmallVector<Value> delinearizedIvs(ubs.size());
567 SmallPtrSet<Operation *, 2> preservedUsers;
568 for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
569 unsigned idx = ubs.size() - i - 1;
570 if (i != 0) {
571 previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
572 preservedUsers.insert(Ptr: previous.getDefiningOp());
573 }
574 Value iv = previous;
575 if (i != e - 1) {
576 iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
577 preservedUsers.insert(Ptr: iv.getDefiningOp());
578 }
579 delinearizedIvs[idx] = iv;
580 }
581 return {delinearizedIvs, preservedUsers};
582}
583
584LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
585 MutableArrayRef<scf::ForOp> loops) {
586 if (loops.size() < 2)
587 return failure();
588
589 scf::ForOp innermost = loops.back();
590 scf::ForOp outermost = loops.front();
591
592 // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
593 // allows the following code to assume upperBound is the number of iterations.
594 for (auto loop : loops) {
595 OpBuilder::InsertionGuard g(rewriter);
596 rewriter.setInsertionPoint(outermost);
597 Value lb = loop.getLowerBound();
598 Value ub = loop.getUpperBound();
599 Value step = loop.getStep();
600 auto newLoopParams =
601 emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
602
603 rewriter.modifyOpInPlace(loop, [&]() {
604 loop.setLowerBound(newLoopParams.lowerBound);
605 loop.setUpperBound(newLoopParams.upperBound);
606 loop.setStep(newLoopParams.step);
607 });
608
609 rewriter.setInsertionPointToStart(innermost.getBody());
610 denormalizeInductionVariable(rewriter, loop.getLoc(),
611 loop.getInductionVar(), lb, step);
612 }
613
614 // 2. Emit code computing the upper bound of the coalesced loop as product
615 // of the number of iterations of all loops.
616 OpBuilder::InsertionGuard g(rewriter);
617 rewriter.setInsertionPoint(outermost);
618 Location loc = outermost.getLoc();
619 SmallVector<Value> upperBounds = llvm::map_to_vector(
620 loops, [](auto loop) { return loop.getUpperBound(); });
621 Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, values: upperBounds);
622 outermost.setUpperBound(upperBound);
623
624 rewriter.setInsertionPointToStart(innermost.getBody());
625 auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
626 rewriter, loc, outermost.getInductionVar(), upperBounds);
627 rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
628 preservedUsers);
629
630 for (int i = loops.size() - 1; i > 0; --i) {
631 auto outerLoop = loops[i - 1];
632 auto innerLoop = loops[i];
633
634 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
635 auto yieldedVals = llvm::to_vector(Range: innerTerminator->getOperands());
636 rewriter.eraseOp(op: innerTerminator);
637
638 SmallVector<Value> innerBlockArgs;
639 innerBlockArgs.push_back(Elt: delinearizeIvs[i]);
640 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
641 rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
642 Block::iterator(innerLoop), innerBlockArgs);
643 rewriter.replaceOp(innerLoop, yieldedVals);
644 }
645 return success();
646}
647
648LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
649 if (loops.empty()) {
650 return failure();
651 }
652 IRRewriter rewriter(loops.front().getContext());
653 return coalesceLoops(rewriter, loops);
654}
655
656LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
657 LogicalResult result(failure());
658 SmallVector<scf::ForOp> loops;
659 getPerfectlyNestedLoops(loops, op);
660
661 // Look for a band of loops that can be coalesced, i.e. perfectly nested
662 // loops with bounds defined above some loop.
663
664 // 1. For each loop, find above which parent loop its bounds operands are
665 // defined.
666 SmallVector<unsigned> operandsDefinedAbove(loops.size());
667 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
668 operandsDefinedAbove[i] = i;
669 for (unsigned j = 0; j < i; ++j) {
670 SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
671 loops[i].getUpperBound(),
672 loops[i].getStep()};
673 if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
674 operandsDefinedAbove[i] = j;
675 break;
676 }
677 }
678 }
679
680 // 2. For each inner loop check that the iter_args for the immediately outer
681 // loop are the init for the immediately inner loop and that the yields of the
682 // return of the inner loop is the yield for the immediately outer loop. Keep
683 // track of where the chain starts from for each loop.
684 SmallVector<unsigned> iterArgChainStart(loops.size());
685 iterArgChainStart[0] = 0;
686 for (unsigned i = 1, e = loops.size(); i < e; ++i) {
687 // By default set the start of the chain to itself.
688 iterArgChainStart[i] = i;
689 auto outerloop = loops[i - 1];
690 auto innerLoop = loops[i];
691 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
692 continue;
693 }
694 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
695 continue;
696 }
697 auto outerloopTerminator = outerloop.getBody()->getTerminator();
698 if (!llvm::equal(outerloopTerminator->getOperands(),
699 innerLoop.getResults())) {
700 continue;
701 }
702 iterArgChainStart[i] = iterArgChainStart[i - 1];
703 }
704
705 // 3. Identify bands of loops such that the operands of all of them are
706 // defined above the first loop in the band. Traverse the nest bottom-up
707 // so that modifications don't invalidate the inner loops.
708 for (unsigned end = loops.size(); end > 0; --end) {
709 unsigned start = 0;
710 for (; start < end - 1; ++start) {
711 auto maxPos =
712 *std::max_element(first: std::next(x: operandsDefinedAbove.begin(), n: start),
713 last: std::next(x: operandsDefinedAbove.begin(), n: end));
714 if (maxPos > start)
715 continue;
716 if (iterArgChainStart[end - 1] > start)
717 continue;
718 auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
719 if (succeeded(coalesceLoops(band)))
720 result = success();
721 break;
722 }
723 // If a band was found and transformed, keep looking at the loops above
724 // the outermost transformed loop.
725 if (start != end - 1)
726 end = start + 1;
727 }
728 return result;
729}
730
731void mlir::collapseParallelLoops(
732 RewriterBase &rewriter, scf::ParallelOp loops,
733 ArrayRef<std::vector<unsigned>> combinedDimensions) {
734 OpBuilder::InsertionGuard g(rewriter);
735 rewriter.setInsertionPoint(loops);
736 Location loc = loops.getLoc();
737
738 // Presort combined dimensions.
739 auto sortedDimensions = llvm::to_vector<3>(Range&: combinedDimensions);
740 for (auto &dims : sortedDimensions)
741 llvm::sort(C&: dims);
742
743 // Normalize ParallelOp's iteration pattern.
744 SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
745 normalizedUpperBounds;
746 for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
747 OpBuilder::InsertionGuard g2(rewriter);
748 rewriter.setInsertionPoint(loops);
749 Value lb = loops.getLowerBound()[i];
750 Value ub = loops.getUpperBound()[i];
751 Value step = loops.getStep()[i];
752 auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
753 normalizedLowerBounds.push_back(Elt: newLoopParams.lowerBound);
754 normalizedUpperBounds.push_back(Elt: newLoopParams.upperBound);
755 normalizedSteps.push_back(Elt: newLoopParams.step);
756
757 rewriter.setInsertionPointToStart(loops.getBody());
758 denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
759 step);
760 }
761
762 // Combine iteration spaces.
763 SmallVector<Value, 3> lowerBounds, upperBounds, steps;
764 auto cst0 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
765 auto cst1 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
766 for (auto &sortedDimension : sortedDimensions) {
767 Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
768 for (auto idx : sortedDimension) {
769 newUpperBound = rewriter.create<arith::MulIOp>(
770 loc, newUpperBound, normalizedUpperBounds[idx]);
771 }
772 lowerBounds.push_back(Elt: cst0);
773 steps.push_back(Elt: cst1);
774 upperBounds.push_back(Elt: newUpperBound);
775 }
776
777 // Create new ParallelLoop with conversions to the original induction values.
778 // The loop below uses divisions to get the relevant range of values in the
779 // new induction value that represent each range of the original induction
780 // value. The remainders then determine based on that range, which iteration
781 // of the original induction value this represents. This is a normalized value
782 // that is un-normalized already by the previous logic.
783 auto newPloop = rewriter.create<scf::ParallelOp>(
784 loc, lowerBounds, upperBounds, steps,
785 [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
786 for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
787 Value previous = ploopIVs[i];
788 unsigned numberCombinedDimensions = combinedDimensions[i].size();
789 // Iterate over all except the last induction value.
790 for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
791 unsigned idx = combinedDimensions[i][j];
792
793 // Determine the current induction value's current loop iteration
794 Value iv = insideBuilder.create<arith::RemSIOp>(
795 loc, previous, normalizedUpperBounds[idx]);
796 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
797 loops.getRegion());
798
799 // Remove the effect of the current induction value to prepare for
800 // the next value.
801 previous = insideBuilder.create<arith::DivSIOp>(
802 loc, previous, normalizedUpperBounds[idx]);
803 }
804
805 // The final induction value is just the remaining value.
806 unsigned idx = combinedDimensions[i][0];
807 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
808 previous, loops.getRegion());
809 }
810 });
811
812 // Replace the old loop with the new loop.
813 loops.getBody()->back().erase();
814 newPloop.getBody()->getOperations().splice(
815 Block::iterator(newPloop.getBody()->back()),
816 loops.getBody()->getOperations());
817 loops.erase();
818}
819
820// Hoist the ops within `outer` that appear before `inner`.
821// Such ops include the ops that have been introduced by parametric tiling.
822// Ops that come from triangular loops (i.e. that belong to the program slice
823// rooted at `outer`) and ops that have side effects cannot be hoisted.
824// Return failure when any op fails to hoist.
825static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
826 SetVector<Operation *> forwardSlice;
827 ForwardSliceOptions options;
828 options.filter = [&inner](Operation *op) {
829 return op != inner.getOperation();
830 };
831 getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
832 LogicalResult status = success();
833 SmallVector<Operation *, 8> toHoist;
834 for (auto &op : outer.getBody()->without_terminator()) {
835 // Stop when encountering the inner loop.
836 if (&op == inner.getOperation())
837 break;
838 // Skip over non-hoistable ops.
839 if (forwardSlice.count(&op) > 0) {
840 status = failure();
841 continue;
842 }
843 // Skip intermediate scf::ForOp, these are not considered a failure.
844 if (isa<scf::ForOp>(op))
845 continue;
846 // Skip other ops with regions.
847 if (op.getNumRegions() > 0) {
848 status = failure();
849 continue;
850 }
851 // Skip if op has side effects.
852 // TODO: loads to immutable memory regions are ok.
853 if (!isMemoryEffectFree(&op)) {
854 status = failure();
855 continue;
856 }
857 toHoist.push_back(&op);
858 }
859 auto *outerForOp = outer.getOperation();
860 for (auto *op : toHoist)
861 op->moveBefore(outerForOp);
862 return status;
863}
864
865// Traverse the interTile and intraTile loops and try to hoist ops such that
866// bands of perfectly nested loops are isolated.
867// Return failure if either perfect interTile or perfect intraTile bands cannot
868// be formed.
869static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
870 LogicalResult status = success();
871 const Loops &interTile = tileLoops.first;
872 const Loops &intraTile = tileLoops.second;
873 auto size = interTile.size();
874 assert(size == intraTile.size());
875 if (size <= 1)
876 return success();
877 for (unsigned s = 1; s < size; ++s)
878 status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
879 : failure();
880 for (unsigned s = 1; s < size; ++s)
881 status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
882 : failure();
883 return status;
884}
885
886/// Collect perfectly nested loops starting from `rootForOps`. Loops are
887/// perfectly nested if each loop is the first and only non-terminator operation
888/// in the parent loop. Collect at most `maxLoops` loops and append them to
889/// `forOps`.
890template <typename T>
891static void getPerfectlyNestedLoopsImpl(
892 SmallVectorImpl<T> &forOps, T rootForOp,
893 unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
894 for (unsigned i = 0; i < maxLoops; ++i) {
895 forOps.push_back(rootForOp);
896 Block &body = rootForOp.getRegion().front();
897 if (body.begin() != std::prev(x: body.end(), n: 2))
898 return;
899
900 rootForOp = dyn_cast<T>(&body.front());
901 if (!rootForOp)
902 return;
903 }
904}
905
906static Loops stripmineSink(scf::ForOp forOp, Value factor,
907 ArrayRef<scf::ForOp> targets) {
908 auto originalStep = forOp.getStep();
909 auto iv = forOp.getInductionVar();
910
911 OpBuilder b(forOp);
912 forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
913
914 Loops innerLoops;
915 for (auto t : targets) {
916 // Save information for splicing ops out of t when done
917 auto begin = t.getBody()->begin();
918 auto nOps = t.getBody()->getOperations().size();
919
920 // Insert newForOp before the terminator of `t`.
921 auto b = OpBuilder::atBlockTerminator((t.getBody()));
922 Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
923 Value ub =
924 b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
925
926 // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
927 auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
928 newForOp.getBody()->getOperations().splice(
929 newForOp.getBody()->getOperations().begin(),
930 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
931 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
932 newForOp.getRegion());
933
934 innerLoops.push_back(newForOp);
935 }
936
937 return innerLoops;
938}
939
940// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
941// Returns the new for operation, nested immediately under `target`.
942template <typename SizeType>
943static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
944 scf::ForOp target) {
945 // TODO: Use cheap structural assertions that targets are nested under
946 // forOp and that targets are not nested under each other when DominanceInfo
947 // exposes the capability. It seems overkill to construct a whole function
948 // dominance tree at this point.
949 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
950 assert(res.size() == 1 && "Expected 1 inner forOp");
951 return res[0];
952}
953
954SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
955 ArrayRef<Value> sizes,
956 ArrayRef<scf::ForOp> targets) {
957 SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
958 SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end());
959 for (auto it : llvm::zip(forOps, sizes)) {
960 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
961 res.push_back(step);
962 currentTargets = step;
963 }
964 return res;
965}
966
967Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
968 scf::ForOp target) {
969 SmallVector<scf::ForOp, 8> res;
970 for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
971 assert(loops.size() == 1);
972 res.push_back(loops[0]);
973 }
974 return res;
975}
976
977Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
978 // Collect perfectly nested loops. If more size values provided than nested
979 // loops available, truncate `sizes`.
980 SmallVector<scf::ForOp, 4> forOps;
981 forOps.reserve(sizes.size());
982 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
983 if (forOps.size() < sizes.size())
984 sizes = sizes.take_front(N: forOps.size());
985
986 return ::tile(forOps, sizes, forOps.back());
987}
988
/// Collects all perfectly nested loops rooted at `root` into `nestedLoops`,
/// delegating to the shared implementation with no limit on nest depth.
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}
993
994TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
995 ArrayRef<int64_t> sizes) {
996 // Collect perfectly nested loops. If more size values provided than nested
997 // loops available, truncate `sizes`.
998 SmallVector<scf::ForOp, 4> forOps;
999 forOps.reserve(sizes.size());
1000 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1001 if (forOps.size() < sizes.size())
1002 sizes = sizes.take_front(N: forOps.size());
1003
1004 // Compute the tile sizes such that i-th outer loop executes size[i]
1005 // iterations. Given that the loop current executes
1006 // numIterations = ceildiv((upperBound - lowerBound), step)
1007 // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1008 SmallVector<Value, 4> tileSizes;
1009 tileSizes.reserve(N: sizes.size());
1010 for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1011 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1012
1013 auto forOp = forOps[i];
1014 OpBuilder builder(forOp);
1015 auto loc = forOp.getLoc();
1016 Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
1017 forOp.getLowerBound());
1018 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1019 Value iterationsPerBlock =
1020 ceilDivPositive(builder, loc, numIterations, sizes[i]);
1021 tileSizes.push_back(Elt: iterationsPerBlock);
1022 }
1023
1024 // Call parametric tiling with the given sizes.
1025 auto intraTile = tile(forOps, tileSizes, forOps.back());
1026 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1027
1028 // TODO: for now we just ignore the result of band isolation.
1029 // In the future, mapping decisions may be impacted by the ability to
1030 // isolate perfectly nested bands.
1031 (void)tryIsolateBands(tileLoops);
1032
1033 return tileLoops;
1034}
1035
1036scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
1037 scf::ForallOp source,
1038 RewriterBase &rewriter) {
1039 unsigned numTargetOuts = target.getNumResults();
1040 unsigned numSourceOuts = source.getNumResults();
1041
1042 // Create fused shared_outs.
1043 SmallVector<Value> fusedOuts;
1044 llvm::append_range(fusedOuts, target.getOutputs());
1045 llvm::append_range(fusedOuts, source.getOutputs());
1046
1047 // Create a new scf.forall op after the source loop.
1048 rewriter.setInsertionPointAfter(source);
1049 scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
1050 source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
1051 source.getMixedStep(), fusedOuts, source.getMapping());
1052
1053 // Map control operands.
1054 IRMapping mapping;
1055 mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1056 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1057
1058 // Map shared outs.
1059 mapping.map(target.getRegionIterArgs(),
1060 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1061 mapping.map(source.getRegionIterArgs(),
1062 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1063
1064 // Append everything except the terminator into the fused operation.
1065 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1066 for (Operation &op : target.getBody()->without_terminator())
1067 rewriter.clone(op, mapping);
1068 for (Operation &op : source.getBody()->without_terminator())
1069 rewriter.clone(op, mapping);
1070
1071 // Fuse the old terminator in_parallel ops into the new one.
1072 scf::InParallelOp targetTerm = target.getTerminator();
1073 scf::InParallelOp sourceTerm = source.getTerminator();
1074 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1075 rewriter.setInsertionPointToStart(fusedTerm.getBody());
1076 for (Operation &op : targetTerm.getYieldingOps())
1077 rewriter.clone(op, mapping);
1078 for (Operation &op : sourceTerm.getYieldingOps())
1079 rewriter.clone(op, mapping);
1080
1081 // Replace old loops by substituting their uses by results of the fused loop.
1082 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1083 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1084
1085 return fusedLoop;
1086}
1087
1088scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
1089 scf::ForOp source,
1090 RewriterBase &rewriter) {
1091 unsigned numTargetOuts = target.getNumResults();
1092 unsigned numSourceOuts = source.getNumResults();
1093
1094 // Create fused init_args, with target's init_args before source's init_args.
1095 SmallVector<Value> fusedInitArgs;
1096 llvm::append_range(fusedInitArgs, target.getInitArgs());
1097 llvm::append_range(fusedInitArgs, source.getInitArgs());
1098
1099 // Create a new scf.for op after the source loop (with scf.yield terminator
1100 // (without arguments) only in case its init_args is empty).
1101 rewriter.setInsertionPointAfter(source);
1102 scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
1103 source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1104 source.getStep(), fusedInitArgs);
1105
1106 // Map original induction variables and operands to those of the fused loop.
1107 IRMapping mapping;
1108 mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1109 mapping.map(target.getRegionIterArgs(),
1110 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1111 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1112 mapping.map(source.getRegionIterArgs(),
1113 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1114
1115 // Merge target's body into the new (fused) for loop and then source's body.
1116 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1117 for (Operation &op : target.getBody()->without_terminator())
1118 rewriter.clone(op, mapping);
1119 for (Operation &op : source.getBody()->without_terminator())
1120 rewriter.clone(op, mapping);
1121
1122 // Build fused yield results by appropriately mapping original yield operands.
1123 SmallVector<Value> yieldResults;
1124 for (Value operand : target.getBody()->getTerminator()->getOperands())
1125 yieldResults.push_back(mapping.lookupOrDefault(operand));
1126 for (Value operand : source.getBody()->getTerminator()->getOperands())
1127 yieldResults.push_back(mapping.lookupOrDefault(operand));
1128 if (!yieldResults.empty())
1129 rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1130
1131 // Replace old loops by substituting their uses by results of the fused loop.
1132 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1133 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1134
1135 return fusedLoop;
1136}
1137

// source: mlir/lib/Dialect/SCF/Utils/Utils.cpp