1//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop software pipelining
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SCF/Transforms/Patterns.h"
16#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17#include "mlir/Dialect/SCF/Utils/Utils.h"
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/RegionUtils.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/MathExtras.h"
24
25#define DEBUG_TYPE "scf-loop-pipelining"
26#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28
29using namespace mlir;
30using namespace mlir::scf;
31
32namespace {
33
34/// Helper to keep internal information during pipelining transformation.
35struct LoopPipelinerInternal {
36 /// Coarse liverange information for ops used across stages.
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
40 };
41
42protected:
43 ForOp forOp;
44 unsigned maxStage = 0;
45 DenseMap<Operation *, unsigned> stages;
46 std::vector<Operation *> opOrder;
47 Value ub;
48 Value lb;
49 Value step;
50 bool dynamicLoop;
51 PipeliningOption::AnnotationlFnType annotateFn = nullptr;
52 bool peelEpilogue;
53 PipeliningOption::PredicateOpFn predicateFn = nullptr;
54
55 // When peeling the kernel we generate several version of each value for
56 // different stage of the prologue. This map tracks the mapping between
57 // original Values in the loop and the different versions
58 // peeled from the loop.
59 DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
60
61 /// Assign a value to `valueMapping`, this means `val` represents the version
62 /// `idx` of `key` in the epilogue.
63 void setValueMapping(Value key, Value el, int64_t idx);
64
65 /// Return the defining op of the given value, if the Value is an argument of
66 /// the loop return the associated defining op in the loop and its distance to
67 /// the Value.
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
69
70 /// Return true if the schedule is possible and return false otherwise. A
71 /// schedule is correct if all definitions are scheduled before uses.
72 bool verifySchedule();
73
74public:
75 /// Initalize the information for the given `op`, return true if it
76 /// satisfies the pre-condition to apply pipelining.
77 bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
78 /// Emits the prologue, this creates `maxStage - 1` part which will contain
79 /// operations from stages [0; i], where i is the part index.
80 LogicalResult emitPrologue(RewriterBase &rewriter);
81 /// Gather liverange information for Values that are used in a different stage
82 /// than its definition.
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
86 RewriterBase &rewriter,
87 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
88 /// Emits the pipelined kernel. This clones loop operations following user
89 /// order and remaps operands defined in a different stage as their use.
90 LogicalResult createKernel(
91 scf::ForOp newForOp,
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
94 RewriterBase &rewriter);
95 /// Emits the epilogue, this creates `maxStage - 1` part which will contain
96 /// operations from stages [i; maxStage], where i is the part index.
97 LogicalResult emitEpilogue(RewriterBase &rewriter,
98 llvm::SmallVector<Value> &returnValues);
99};
100
101bool LoopPipelinerInternal::initializeLoopInfo(
102 ForOp op, const PipeliningOption &options) {
103 LDBG("Start initializeLoopInfo");
104 forOp = op;
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
108
109 std::vector<std::pair<Operation *, unsigned>> schedule;
110 options.getScheduleFn(forOp, schedule);
111 if (schedule.empty()) {
112 LDBG("--empty schedule -> BAIL");
113 return false;
114 }
115
116 opOrder.reserve(n: schedule.size());
117 for (auto &opSchedule : schedule) {
118 maxStage = std::max(a: maxStage, b: opSchedule.second);
119 stages[opSchedule.first] = opSchedule.second;
120 opOrder.push_back(x: opSchedule.first);
121 }
122
123 dynamicLoop = true;
124 auto upperBoundCst = getConstantIntValue(ofr: ub);
125 auto lowerBoundCst = getConstantIntValue(ofr: lb);
126 auto stepCst = getConstantIntValue(ofr: step);
127 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
128 if (!options.supportDynamicLoops) {
129 LDBG("--dynamic loop not supported -> BAIL");
130 return false;
131 }
132 } else {
133 int64_t ubImm = upperBoundCst.value();
134 int64_t lbImm = lowerBoundCst.value();
135 int64_t stepImm = stepCst.value();
136 if (stepImm <= 0) {
137 LDBG("--invalid loop step -> BAIL");
138 return false;
139 }
140 int64_t numIteration = llvm::divideCeilSigned(Numerator: ubImm - lbImm, Denominator: stepImm);
141 if (numIteration >= maxStage) {
142 dynamicLoop = false;
143 } else if (!options.supportDynamicLoops) {
144 LDBG("--fewer loop iterations than pipeline stages -> BAIL");
145 return false;
146 }
147 }
148 peelEpilogue = options.peelEpilogue;
149 predicateFn = options.predicateFn;
150 if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
151 LDBG("--no epilogue or predicate set -> BAIL");
152 return false;
153 }
154
155 // All operations need to have a stage.
156 for (Operation &op : forOp.getBody()->without_terminator()) {
157 if (!stages.contains(Val: &op)) {
158 op.emitOpError(message: "not assigned a pipeline stage");
159 LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
160 return false;
161 }
162 }
163
164 if (!verifySchedule()) {
165 LDBG("--invalid schedule: " << op << " -> BAIL");
166 return false;
167 }
168
169 // Currently, we do not support assigning stages to ops in nested regions. The
170 // block of all operations assigned a stage should be the single `scf.for`
171 // body block.
172 for (const auto &[op, stageNum] : stages) {
173 (void)stageNum;
174 if (op == forOp.getBody()->getTerminator()) {
175 op->emitError(message: "terminator should not be assigned a stage");
176 LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
177 return false;
178 }
179 if (op->getBlock() != forOp.getBody()) {
180 op->emitOpError(message: "the owning Block of all operations assigned a stage "
181 "should be the loop body block");
182 LDBG("--the owning Block of all operations assigned a stage "
183 "should be the loop body block: "
184 << *op << " -> BAIL");
185 return false;
186 }
187 }
188
189 // Support only loop-carried dependencies with a distance of one iteration or
190 // those defined outside of the loop. This means that any dependency within a
191 // loop should either be on the immediately preceding iteration, the current
192 // iteration, or on variables whose values are set before entering the loop.
193 if (llvm::any_of(Range: forOp.getBody()->getTerminator()->getOperands(),
194 P: [this](Value operand) {
195 Operation *def = operand.getDefiningOp();
196 return !def ||
197 (!stages.contains(Val: def) && forOp->isAncestor(other: def));
198 })) {
199 LDBG("--only support loop carried dependency with a distance of 1 or "
200 "defined outside of the loop -> BAIL");
201 return false;
202 }
203 annotateFn = options.annotateFn;
204 return true;
205}
206
207/// Find operands of all the nested operations within `op`.
208static SetVector<Value> getNestedOperands(Operation *op) {
209 SetVector<Value> operands;
210 op->walk(callback: [&](Operation *nestedOp) {
211 operands.insert_range(R: nestedOp->getOperands());
212 });
213 return operands;
214}
215
216/// Compute unrolled cycles of each op (consumer) and verify that each op is
217/// scheduled after its operands (producers) while adjusting for the distance
218/// between producer and consumer.
219bool LoopPipelinerInternal::verifySchedule() {
220 int64_t numCylesPerIter = opOrder.size();
221 // Pre-compute the unrolled cycle of each op.
222 DenseMap<Operation *, int64_t> unrolledCyles;
223 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
224 Operation *def = opOrder[cycle];
225 auto it = stages.find(Val: def);
226 assert(it != stages.end());
227 int64_t stage = it->second;
228 unrolledCyles[def] = cycle + stage * numCylesPerIter;
229 }
230 for (Operation *consumer : opOrder) {
231 int64_t consumerCycle = unrolledCyles[consumer];
232 for (Value operand : getNestedOperands(op: consumer)) {
233 auto [producer, distance] = getDefiningOpAndDistance(value: operand);
234 if (!producer)
235 continue;
236 auto it = unrolledCyles.find(Val: producer);
237 // Skip producer coming from outside the loop.
238 if (it == unrolledCyles.end())
239 continue;
240 int64_t producerCycle = it->second;
241 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
242 consumer->emitError(message: "operation scheduled before its operands");
243 return false;
244 }
245 }
246 }
247 return true;
248}
249
250/// Clone `op` and call `callback` on the cloned op's oeprands as well as any
251/// operands of nested ops that:
252/// 1) aren't defined within the new op or
253/// 2) are block arguments.
254static Operation *
255cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
256 function_ref<void(OpOperand *newOperand)> callback) {
257 Operation *clone = rewriter.clone(op&: *op);
258 clone->walk<WalkOrder::PreOrder>(callback: [&](Operation *nested) {
259 // 'clone' itself will be visited first.
260 for (OpOperand &operand : nested->getOpOperands()) {
261 Operation *def = operand.get().getDefiningOp();
262 if ((def && !clone->isAncestor(other: def)) || isa<BlockArgument>(Val: operand.get()))
263 callback(&operand);
264 }
265 });
266 return clone;
267}
268
269LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
270 // Initialize the iteration argument to the loop initial values.
271 for (auto [arg, operand] :
272 llvm::zip(t: forOp.getRegionIterArgs(), u: forOp.getInitsMutable())) {
273 setValueMapping(key: arg, el: operand.get(), idx: 0);
274 }
275 auto yield = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator());
276 Location loc = forOp.getLoc();
277 SmallVector<Value> predicates(maxStage);
278 for (int64_t i = 0; i < maxStage; i++) {
279 if (dynamicLoop) {
280 Type t = ub.getType();
281 // pred = ub > lb + (i * step)
282 Value iv = rewriter.create<arith::AddIOp>(
283 location: loc, args&: lb,
284 args: rewriter.create<arith::MulIOp>(
285 location: loc, args&: step,
286 args: rewriter.create<arith::ConstantOp>(
287 location: loc, args: rewriter.getIntegerAttr(type: t, value: i))));
288 predicates[i] = rewriter.create<arith::CmpIOp>(
289 location: loc, args: arith::CmpIPredicate::slt, args&: iv, args&: ub);
290 }
291
292 // special handling for induction variable as the increment is implicit.
293 // iv = lb + i * step
294 Type t = lb.getType();
295 Value iv = rewriter.create<arith::AddIOp>(
296 location: loc, args&: lb,
297 args: rewriter.create<arith::MulIOp>(
298 location: loc, args&: step,
299 args: rewriter.create<arith::ConstantOp>(location: loc,
300 args: rewriter.getIntegerAttr(type: t, value: i))));
301 setValueMapping(key: forOp.getInductionVar(), el: iv, idx: i);
302 for (Operation *op : opOrder) {
303 if (stages[op] > i)
304 continue;
305 Operation *newOp =
306 cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) {
307 auto it = valueMapping.find(Val: newOperand->get());
308 if (it != valueMapping.end()) {
309 Value replacement = it->second[i - stages[op]];
310 newOperand->set(replacement);
311 }
312 });
313 int predicateIdx = i - stages[op];
314 if (predicates[predicateIdx]) {
315 OpBuilder::InsertionGuard insertGuard(rewriter);
316 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
317 if (newOp == nullptr)
318 return failure();
319 }
320 if (annotateFn)
321 annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
322 for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) {
323 Value source = newOp->getResult(idx: destId);
324 // If the value is a loop carried dependency update the loop argument
325 for (OpOperand &operand : yield->getOpOperands()) {
326 if (operand.get() != op->getResult(idx: destId))
327 continue;
328 if (predicates[predicateIdx] &&
329 !forOp.getResult(i: operand.getOperandNumber()).use_empty()) {
330 // If the value is used outside the loop, we need to make sure we
331 // return the correct version of it.
332 Value prevValue = valueMapping
333 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
334 [i - stages[op]];
335 source = rewriter.create<arith::SelectOp>(
336 location: loc, args&: predicates[predicateIdx], args&: source, args&: prevValue);
337 }
338 setValueMapping(key: forOp.getRegionIterArgs()[operand.getOperandNumber()],
339 el: source, idx: i - stages[op] + 1);
340 }
341 setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId),
342 idx: i - stages[op]);
343 }
344 }
345 }
346 return success();
347}
348
349llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
350LoopPipelinerInternal::analyzeCrossStageValues() {
351 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
352 for (Operation *op : opOrder) {
353 unsigned stage = stages[op];
354
355 auto analyzeOperand = [&](OpOperand &operand) {
356 auto [def, distance] = getDefiningOpAndDistance(value: operand.get());
357 if (!def)
358 return;
359 auto defStage = stages.find(Val: def);
360 if (defStage == stages.end() || defStage->second == stage ||
361 defStage->second == stage + distance)
362 return;
363 assert(stage > defStage->second);
364 LiverangeInfo &info = crossStageValues[operand.get()];
365 info.defStage = defStage->second;
366 info.lastUseStage = std::max(a: info.lastUseStage, b: stage);
367 };
368
369 for (OpOperand &operand : op->getOpOperands())
370 analyzeOperand(operand);
371 visitUsedValuesDefinedAbove(regions: op->getRegions(), callback: [&](OpOperand *operand) {
372 analyzeOperand(*operand);
373 });
374 }
375 return crossStageValues;
376}
377
378std::pair<Operation *, int64_t>
379LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
380 int64_t distance = 0;
381 if (auto arg = dyn_cast<BlockArgument>(Val&: value)) {
382 if (arg.getOwner() != forOp.getBody())
383 return {nullptr, 0};
384 // Ignore induction variable.
385 if (arg.getArgNumber() == 0)
386 return {nullptr, 0};
387 distance++;
388 value =
389 forOp.getBody()->getTerminator()->getOperand(idx: arg.getArgNumber() - 1);
390 }
391 Operation *def = value.getDefiningOp();
392 if (!def)
393 return {nullptr, 0};
394 return {def, distance};
395}
396
397scf::ForOp LoopPipelinerInternal::createKernelLoop(
398 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
399 &crossStageValues,
400 RewriterBase &rewriter,
401 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
402 // Creates the list of initial values associated to values used across
403 // stages. The initial values come from the prologue created above.
404 // Keep track of the kernel argument associated to each version of the
405 // values passed to the kernel.
406 llvm::SmallVector<Value> newLoopArg;
407 // For existing loop argument initialize them with the right version from the
408 // prologue.
409 for (const auto &retVal :
410 llvm::enumerate(First: forOp.getBody()->getTerminator()->getOperands())) {
411 Operation *def = retVal.value().getDefiningOp();
412 assert(def && "Only support loop carried dependencies of distance of 1 or "
413 "outside the loop");
414 auto defStage = stages.find(Val: def);
415 if (defStage != stages.end()) {
416 Value valueVersion =
417 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
418 [maxStage - defStage->second];
419 assert(valueVersion);
420 newLoopArg.push_back(Elt: valueVersion);
421 } else {
422 newLoopArg.push_back(Elt: forOp.getInitArgs()[retVal.index()]);
423 }
424 }
425 for (auto escape : crossStageValues) {
426 LiverangeInfo &info = escape.second;
427 Value value = escape.first;
428 for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
429 stageIdx++) {
430 Value valueVersion =
431 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
432 assert(valueVersion);
433 newLoopArg.push_back(Elt: valueVersion);
434 loopArgMap[std::make_pair(x&: value, y: info.lastUseStage - info.defStage -
435 stageIdx)] = newLoopArg.size() - 1;
436 }
437 }
438
439 // Create the new kernel loop. When we peel the epilgue we need to peel
440 // `numStages - 1` iterations. Then we adjust the upper bound to remove those
441 // iterations.
442 Value newUb = forOp.getUpperBound();
443 if (peelEpilogue) {
444 Type t = ub.getType();
445 Location loc = forOp.getLoc();
446 // newUb = ub - maxStage * step
447 Value maxStageValue = rewriter.create<arith::ConstantOp>(
448 location: loc, args: rewriter.getIntegerAttr(type: t, value: maxStage));
449 Value maxStageByStep =
450 rewriter.create<arith::MulIOp>(location: loc, args&: step, args&: maxStageValue);
451 newUb = rewriter.create<arith::SubIOp>(location: loc, args&: ub, args&: maxStageByStep);
452 }
453 auto newForOp =
454 rewriter.create<scf::ForOp>(location: forOp.getLoc(), args: forOp.getLowerBound(), args&: newUb,
455 args: forOp.getStep(), args&: newLoopArg);
456 // When there are no iter args, the loop body terminator will be created.
457 // Since we always create it below, remove the terminator if it was created.
458 if (!newForOp.getBody()->empty())
459 rewriter.eraseOp(op: newForOp.getBody()->getTerminator());
460 return newForOp;
461}
462
463LogicalResult LoopPipelinerInternal::createKernel(
464 scf::ForOp newForOp,
465 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
466 &crossStageValues,
467 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
468 RewriterBase &rewriter) {
469 valueMapping.clear();
470
471 // Create the kernel, we clone instruction based on the order given by
472 // user and remap operands coming from a previous stages.
473 rewriter.setInsertionPoint(block: newForOp.getBody(), insertPoint: newForOp.getBody()->begin());
474 IRMapping mapping;
475 mapping.map(from: forOp.getInductionVar(), to: newForOp.getInductionVar());
476 for (const auto &arg : llvm::enumerate(First: forOp.getRegionIterArgs())) {
477 mapping.map(from: arg.value(), to: newForOp.getRegionIterArgs()[arg.index()]);
478 }
479 SmallVector<Value> predicates(maxStage + 1, nullptr);
480 if (!peelEpilogue) {
481 // Create a predicate for each stage except the last stage.
482 Location loc = newForOp.getLoc();
483 Type t = ub.getType();
484 for (unsigned i = 0; i < maxStage; i++) {
485 // c = ub - (maxStage - i) * step
486 Value c = rewriter.create<arith::SubIOp>(
487 location: loc, args&: ub,
488 args: rewriter.create<arith::MulIOp>(
489 location: loc, args&: step,
490 args: rewriter.create<arith::ConstantOp>(
491 location: loc, args: rewriter.getIntegerAttr(type: t, value: int64_t(maxStage - i)))));
492
493 Value pred = rewriter.create<arith::CmpIOp>(
494 location: newForOp.getLoc(), args: arith::CmpIPredicate::slt,
495 args: newForOp.getInductionVar(), args&: c);
496 predicates[i] = pred;
497 }
498 }
499 for (Operation *op : opOrder) {
500 int64_t useStage = stages[op];
501 auto *newOp = rewriter.clone(op&: *op, mapper&: mapping);
502 SmallVector<OpOperand *> operands;
503 // Collect all the operands for the cloned op and its nested ops.
504 op->walk(callback: [&operands](Operation *nestedOp) {
505 for (OpOperand &operand : nestedOp->getOpOperands()) {
506 operands.push_back(Elt: &operand);
507 }
508 });
509 for (OpOperand *operand : operands) {
510 Operation *nestedNewOp = mapping.lookup(from: operand->getOwner());
511 // Special case for the induction variable uses. We replace it with a
512 // version incremented based on the stage where it is used.
513 if (operand->get() == forOp.getInductionVar()) {
514 rewriter.setInsertionPoint(newOp);
515
516 // offset = (maxStage - stages[op]) * step
517 Type t = step.getType();
518 Value offset = rewriter.create<arith::MulIOp>(
519 location: forOp.getLoc(), args&: step,
520 args: rewriter.create<arith::ConstantOp>(
521 location: forOp.getLoc(),
522 args: rewriter.getIntegerAttr(type: t, value: maxStage - stages[op])));
523 Value iv = rewriter.create<arith::AddIOp>(
524 location: forOp.getLoc(), args: newForOp.getInductionVar(), args&: offset);
525 nestedNewOp->setOperand(idx: operand->getOperandNumber(), value: iv);
526 rewriter.setInsertionPointAfter(newOp);
527 continue;
528 }
529 Value source = operand->get();
530 auto arg = dyn_cast<BlockArgument>(Val&: source);
531 if (arg && arg.getOwner() == forOp.getBody()) {
532 Value ret = forOp.getBody()->getTerminator()->getOperand(
533 idx: arg.getArgNumber() - 1);
534 Operation *dep = ret.getDefiningOp();
535 if (!dep)
536 continue;
537 auto stageDep = stages.find(Val: dep);
538 if (stageDep == stages.end() || stageDep->second == useStage)
539 continue;
540 // If the value is a loop carried value coming from stage N + 1 remap,
541 // it will become a direct use.
542 if (stageDep->second == useStage + 1) {
543 nestedNewOp->setOperand(idx: operand->getOperandNumber(),
544 value: mapping.lookupOrDefault(from: ret));
545 continue;
546 }
547 source = ret;
548 }
549 // For operands defined in a previous stage we need to remap it to use
550 // the correct region argument. We look for the right version of the
551 // Value based on the stage where it is used.
552 Operation *def = source.getDefiningOp();
553 if (!def)
554 continue;
555 auto stageDef = stages.find(Val: def);
556 if (stageDef == stages.end() || stageDef->second == useStage)
557 continue;
558 auto remap = loopArgMap.find(
559 Val: std::make_pair(x: operand->get(), y: useStage - stageDef->second));
560 assert(remap != loopArgMap.end());
561 nestedNewOp->setOperand(idx: operand->getOperandNumber(),
562 value: newForOp.getRegionIterArgs()[remap->second]);
563 }
564
565 if (predicates[useStage]) {
566 OpBuilder::InsertionGuard insertGuard(rewriter);
567 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
568 if (!newOp)
569 return failure();
570 // Remap the results to the new predicated one.
571 for (auto values : llvm::zip(t: op->getResults(), u: newOp->getResults()))
572 mapping.map(from: std::get<0>(t&: values), to: std::get<1>(t&: values));
573 }
574 if (annotateFn)
575 annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
576 }
577
578 // Collect the Values that need to be returned by the forOp. For each
579 // value we need to have `LastUseStage - DefStage` number of versions
580 // returned.
581 // We create a mapping between original values and the associated loop
582 // returned values that will be needed by the epilogue.
583 llvm::SmallVector<Value> yieldOperands;
584 for (OpOperand &yieldOperand :
585 forOp.getBody()->getTerminator()->getOpOperands()) {
586 Value source = mapping.lookupOrDefault(from: yieldOperand.get());
587 // When we don't peel the epilogue and the yield value is used outside the
588 // loop we need to make sure we return the version from numStages -
589 // defStage.
590 if (!peelEpilogue &&
591 !forOp.getResult(i: yieldOperand.getOperandNumber()).use_empty()) {
592 Operation *def = getDefiningOpAndDistance(value: yieldOperand.get()).first;
593 if (def) {
594 auto defStage = stages.find(Val: def);
595 if (defStage != stages.end() && defStage->second < maxStage) {
596 Value pred = predicates[defStage->second];
597 source = rewriter.create<arith::SelectOp>(
598 location: pred.getLoc(), args&: pred, args&: source,
599 args&: newForOp.getBody()
600 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
601 }
602 }
603 }
604 yieldOperands.push_back(Elt: source);
605 }
606
607 for (auto &it : crossStageValues) {
608 int64_t version = maxStage - it.second.lastUseStage + 1;
609 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
610 // add the original version to yield ops.
611 // If there is a live range spanning across more than 2 stages we need to
612 // add extra arg.
613 for (unsigned i = 1; i < numVersionReturned; i++) {
614 setValueMapping(key: it.first, el: newForOp->getResult(idx: yieldOperands.size()),
615 idx: version++);
616 yieldOperands.push_back(
617 Elt: newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
618 newForOp.getNumInductionVars()]);
619 }
620 setValueMapping(key: it.first, el: newForOp->getResult(idx: yieldOperands.size()),
621 idx: version++);
622 yieldOperands.push_back(Elt: mapping.lookupOrDefault(from: it.first));
623 }
624 // Map the yield operand to the forOp returned value.
625 for (const auto &retVal :
626 llvm::enumerate(First: forOp.getBody()->getTerminator()->getOperands())) {
627 Operation *def = retVal.value().getDefiningOp();
628 assert(def && "Only support loop carried dependencies of distance of 1 or "
629 "defined outside the loop");
630 auto defStage = stages.find(Val: def);
631 if (defStage == stages.end()) {
632 for (unsigned int stage = 1; stage <= maxStage; stage++)
633 setValueMapping(key: forOp.getRegionIterArgs()[retVal.index()],
634 el: retVal.value(), idx: stage);
635 } else if (defStage->second > 0) {
636 setValueMapping(key: forOp.getRegionIterArgs()[retVal.index()],
637 el: newForOp->getResult(idx: retVal.index()),
638 idx: maxStage - defStage->second + 1);
639 }
640 }
641 rewriter.create<scf::YieldOp>(location: forOp.getLoc(), args&: yieldOperands);
642 return success();
643}
644
645LogicalResult
646LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
647 llvm::SmallVector<Value> &returnValues) {
648 Location loc = forOp.getLoc();
649 Type t = lb.getType();
650
651 // Emit different versions of the induction variable. They will be
652 // removed by dead code if not used.
653
654 auto createConst = [&](int v) {
655 return rewriter.create<arith::ConstantOp>(location: loc,
656 args: rewriter.getIntegerAttr(type: t, value: v));
657 };
658
659 // total_iterations = cdiv(range_diff, step);
660 // - range_diff = ub - lb
661 // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
662 Value zero = createConst(0);
663 Value one = createConst(1);
664 Value stepLessZero = rewriter.create<arith::CmpIOp>(
665 location: loc, args: arith::CmpIPredicate::slt, args&: step, args&: zero);
666 Value stepDecr =
667 rewriter.create<arith::SelectOp>(location: loc, args&: stepLessZero, args&: one, args: createConst(-1));
668
669 Value rangeDiff = rewriter.create<arith::SubIOp>(location: loc, args&: ub, args&: lb);
670 Value rangeIncrStep = rewriter.create<arith::AddIOp>(location: loc, args&: rangeDiff, args&: step);
671 Value rangeDecr =
672 rewriter.create<arith::AddIOp>(location: loc, args&: rangeIncrStep, args&: stepDecr);
673 Value totalIterations = rewriter.create<arith::DivSIOp>(location: loc, args&: rangeDecr, args&: step);
674
675 // If total_iters < max_stage, start the epilogue at zero to match the
676 // ramp-up in the prologue.
677 // start_iter = max(0, total_iters - max_stage)
678 Value iterI = rewriter.create<arith::SubIOp>(location: loc, args&: totalIterations,
679 args: createConst(maxStage));
680 iterI = rewriter.create<arith::MaxSIOp>(location: loc, args&: zero, args&: iterI);
681
682 // Capture predicates for dynamic loops.
683 SmallVector<Value> predicates(maxStage + 1);
684
685 for (int64_t i = 1; i <= maxStage; i++) {
686 // newLastIter = lb + step * iterI
687 Value newlastIter = rewriter.create<arith::AddIOp>(
688 location: loc, args&: lb, args: rewriter.create<arith::MulIOp>(location: loc, args&: step, args&: iterI));
689
690 setValueMapping(key: forOp.getInductionVar(), el: newlastIter, idx: i);
691
692 // increment to next iterI
693 iterI = rewriter.create<arith::AddIOp>(location: loc, args&: iterI, args&: one);
694
695 if (dynamicLoop) {
696 // Disable stages when `i` is greater than total_iters.
697 // pred = total_iters >= i
698 predicates[i] = rewriter.create<arith::CmpIOp>(
699 location: loc, args: arith::CmpIPredicate::sge, args&: totalIterations, args: createConst(i));
700 }
701 }
702
703 // Emit `maxStage - 1` epilogue part that includes operations from stages
704 // [i; maxStage].
705 for (int64_t i = 1; i <= maxStage; i++) {
706 SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
707 for (Operation *op : opOrder) {
708 if (stages[op] < i)
709 continue;
710 unsigned currentVersion = maxStage - stages[op] + i;
711 unsigned nextVersion = currentVersion + 1;
712 Operation *newOp =
713 cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) {
714 auto it = valueMapping.find(Val: newOperand->get());
715 if (it != valueMapping.end()) {
716 Value replacement = it->second[currentVersion];
717 newOperand->set(replacement);
718 }
719 });
720 if (dynamicLoop) {
721 OpBuilder::InsertionGuard insertGuard(rewriter);
722 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
723 if (!newOp)
724 return failure();
725 }
726 if (annotateFn)
727 annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
728
729 for (auto [opRes, newRes] :
730 llvm::zip(t: op->getResults(), u: newOp->getResults())) {
731 setValueMapping(key: opRes, el: newRes, idx: currentVersion);
732 // If the value is a loop carried dependency update the loop argument
733 // mapping and keep track of the last version to replace the original
734 // forOp uses.
735 for (OpOperand &operand :
736 forOp.getBody()->getTerminator()->getOpOperands()) {
737 if (operand.get() != opRes)
738 continue;
739 // If the version is greater than maxStage it means it maps to the
740 // original forOp returned value.
741 unsigned ri = operand.getOperandNumber();
742 returnValues[ri] = newRes;
743 Value mapVal = forOp.getRegionIterArgs()[ri];
744 returnMap[ri] = std::make_pair(x&: mapVal, y&: currentVersion);
745 if (nextVersion <= maxStage)
746 setValueMapping(key: mapVal, el: newRes, idx: nextVersion);
747 }
748 }
749 }
750 if (dynamicLoop) {
751 // Select return values from this stage (live outs) based on predication.
752 // If the stage is valid select the peeled value, else use previous stage
753 // value.
754 for (auto pair : llvm::enumerate(First&: returnValues)) {
755 unsigned ri = pair.index();
756 auto [mapVal, currentVersion] = returnMap[ri];
757 if (mapVal) {
758 unsigned nextVersion = currentVersion + 1;
759 Value pred = predicates[currentVersion];
760 Value prevValue = valueMapping[mapVal][currentVersion];
761 auto selOp = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: pair.value(),
762 args&: prevValue);
763 returnValues[ri] = selOp;
764 if (nextVersion <= maxStage)
765 setValueMapping(key: mapVal, el: selOp, idx: nextVersion);
766 }
767 }
768 }
769 }
770 return success();
771}
772
773void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
774 auto it = valueMapping.find(Val: key);
775 // If the value is not in the map yet add a vector big enough to store all
776 // versions.
777 if (it == valueMapping.end())
778 it =
779 valueMapping
780 .insert(KV: std::make_pair(x&: key, y: llvm::SmallVector<Value>(maxStage + 1)))
781 .first;
782 it->second[idx] = el;
783}
784
785} // namespace
786
787FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
788 const PipeliningOption &options,
789 bool *modifiedIR) {
790 if (modifiedIR)
791 *modifiedIR = false;
792 LoopPipelinerInternal pipeliner;
793 if (!pipeliner.initializeLoopInfo(op: forOp, options))
794 return failure();
795
796 if (modifiedIR)
797 *modifiedIR = true;
798
799 // 1. Emit prologue.
800 if (failed(Result: pipeliner.emitPrologue(rewriter)))
801 return failure();
802
803 // 2. Track values used across stages. When a value cross stages it will
804 // need to be passed as loop iteration arguments.
805 // We first collect the values that are used in a different stage than where
806 // they are defined.
807 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
808 crossStageValues = pipeliner.analyzeCrossStageValues();
809
810 // Mapping between original loop values used cross stage and the block
811 // arguments associated after pipelining. A Value may map to several
812 // arguments if its liverange spans across more than 2 stages.
813 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
814 // 3. Create the new kernel loop and return the block arguments mapping.
815 ForOp newForOp =
816 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
817 // Create the kernel block, order ops based on user choice and remap
818 // operands.
819 if (failed(Result: pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
820 rewriter)))
821 return failure();
822
823 llvm::SmallVector<Value> returnValues =
824 newForOp.getResults().take_front(n: forOp->getNumResults());
825 if (options.peelEpilogue) {
826 // 4. Emit the epilogue after the new forOp.
827 rewriter.setInsertionPointAfter(newForOp);
828 if (failed(Result: pipeliner.emitEpilogue(rewriter, returnValues)))
829 return failure();
830 }
831 // 5. Erase the original loop and replace the uses with the epilogue output.
832 if (forOp->getNumResults() > 0)
833 rewriter.replaceOp(op: forOp, newValues: returnValues);
834 else
835 rewriter.eraseOp(op: forOp);
836
837 return newForOp;
838}
839
840void mlir::scf::populateSCFLoopPipeliningPatterns(
841 RewritePatternSet &patterns, const PipeliningOption &options) {
842 patterns.add<ForLoopPipeliningPattern>(arg: options, args: patterns.getContext());
843}
844

source code of mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp