1//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop software pipelining
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SCF/Transforms/Patterns.h"
16#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17#include "mlir/Dialect/SCF/Utils/Utils.h"
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Support/MathExtras.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/MapVector.h"
23#include "llvm/Support/Debug.h"
24
25#define DEBUG_TYPE "scf-loop-pipelining"
26#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28
29using namespace mlir;
30using namespace mlir::scf;
31
32namespace {
33
34/// Helper to keep internal information during pipelining transformation.
35struct LoopPipelinerInternal {
36 /// Coarse liverange information for ops used across stages.
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
40 };
41
42protected:
43 ForOp forOp;
44 unsigned maxStage = 0;
45 DenseMap<Operation *, unsigned> stages;
46 std::vector<Operation *> opOrder;
47 Value ub;
48 Value lb;
49 Value step;
50 bool dynamicLoop;
51 PipeliningOption::AnnotationlFnType annotateFn = nullptr;
52 bool peelEpilogue;
53 PipeliningOption::PredicateOpFn predicateFn = nullptr;
54
55 // When peeling the kernel we generate several version of each value for
56 // different stage of the prologue. This map tracks the mapping between
57 // original Values in the loop and the different versions
58 // peeled from the loop.
59 DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
60
61 /// Assign a value to `valueMapping`, this means `val` represents the version
62 /// `idx` of `key` in the epilogue.
63 void setValueMapping(Value key, Value el, int64_t idx);
64
65 /// Return the defining op of the given value, if the Value is an argument of
66 /// the loop return the associated defining op in the loop and its distance to
67 /// the Value.
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
69
70 /// Return true if the schedule is possible and return false otherwise. A
71 /// schedule is correct if all definitions are scheduled before uses.
72 bool verifySchedule();
73
74public:
75 /// Initalize the information for the given `op`, return true if it
76 /// satisfies the pre-condition to apply pipelining.
77 bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
78 /// Emits the prologue, this creates `maxStage - 1` part which will contain
79 /// operations from stages [0; i], where i is the part index.
80 void emitPrologue(RewriterBase &rewriter);
81 /// Gather liverange information for Values that are used in a different stage
82 /// than its definition.
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
86 RewriterBase &rewriter,
87 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
88 /// Emits the pipelined kernel. This clones loop operations following user
89 /// order and remaps operands defined in a different stage as their use.
90 LogicalResult createKernel(
91 scf::ForOp newForOp,
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
94 RewriterBase &rewriter);
95 /// Emits the epilogue, this creates `maxStage - 1` part which will contain
96 /// operations from stages [i; maxStage], where i is the part index.
97 void emitEpilogue(RewriterBase &rewriter,
98 llvm::SmallVector<Value> &returnValues);
99};
100
101bool LoopPipelinerInternal::initializeLoopInfo(
102 ForOp op, const PipeliningOption &options) {
103 LDBG("Start initializeLoopInfo");
104 forOp = op;
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
108
109 dynamicLoop = true;
110 auto upperBoundCst = getConstantIntValue(ofr: ub);
111 auto lowerBoundCst = getConstantIntValue(ofr: lb);
112 auto stepCst = getConstantIntValue(ofr: step);
113 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114 if (!options.supportDynamicLoops) {
115 LDBG("--dynamic loop not supported -> BAIL");
116 return false;
117 }
118 } else {
119 int64_t ubImm = upperBoundCst.value();
120 int64_t lbImm = lowerBoundCst.value();
121 int64_t stepImm = stepCst.value();
122 int64_t numIteration = ceilDiv(lhs: ubImm - lbImm, rhs: stepImm);
123 if (numIteration > maxStage) {
124 dynamicLoop = false;
125 } else if (!options.supportDynamicLoops) {
126 LDBG("--fewer loop iterations than pipeline stages -> BAIL");
127 return false;
128 }
129 }
130 peelEpilogue = options.peelEpilogue;
131 predicateFn = options.predicateFn;
132 if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
133 LDBG("--no epilogue or predicate set -> BAIL");
134 return false;
135 }
136 if (dynamicLoop && peelEpilogue) {
137 LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
138 return false;
139 }
140 std::vector<std::pair<Operation *, unsigned>> schedule;
141 options.getScheduleFn(forOp, schedule);
142 if (schedule.empty()) {
143 LDBG("--empty schedule -> BAIL");
144 return false;
145 }
146
147 opOrder.reserve(n: schedule.size());
148 for (auto &opSchedule : schedule) {
149 maxStage = std::max(a: maxStage, b: opSchedule.second);
150 stages[opSchedule.first] = opSchedule.second;
151 opOrder.push_back(x: opSchedule.first);
152 }
153
154 // All operations need to have a stage.
155 for (Operation &op : forOp.getBody()->without_terminator()) {
156 if (!stages.contains(&op)) {
157 op.emitOpError("not assigned a pipeline stage");
158 LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
159 return false;
160 }
161 }
162
163 if (!verifySchedule()) {
164 LDBG("--invalid schedule: " << op << " -> BAIL");
165 return false;
166 }
167
168 // Currently, we do not support assigning stages to ops in nested regions. The
169 // block of all operations assigned a stage should be the single `scf.for`
170 // body block.
171 for (const auto &[op, stageNum] : stages) {
172 (void)stageNum;
173 if (op == forOp.getBody()->getTerminator()) {
174 op->emitError(message: "terminator should not be assigned a stage");
175 LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
176 return false;
177 }
178 if (op->getBlock() != forOp.getBody()) {
179 op->emitOpError(message: "the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG("--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op << " -> BAIL");
184 return false;
185 }
186 }
187
188 // Support only loop-carried dependencies with a distance of one iteration or
189 // those defined outside of the loop. This means that any dependency within a
190 // loop should either be on the immediately preceding iteration, the current
191 // iteration, or on variables whose values are set before entering the loop.
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [this](Value operand) {
194 Operation *def = operand.getDefiningOp();
195 return !def ||
196 (!stages.contains(def) && forOp->isAncestor(def));
197 })) {
198 LDBG("--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL");
200 return false;
201 }
202 annotateFn = options.annotateFn;
203 return true;
204}
205
206/// Find operands of all the nested operations within `op`.
207static SetVector<Value> getNestedOperands(Operation *op) {
208 SetVector<Value> operands;
209 op->walk(callback: [&](Operation *nestedOp) {
210 for (Value operand : nestedOp->getOperands()) {
211 operands.insert(X: operand);
212 }
213 });
214 return operands;
215}
216
217/// Compute unrolled cycles of each op (consumer) and verify that each op is
218/// scheduled after its operands (producers) while adjusting for the distance
219/// between producer and consumer.
220bool LoopPipelinerInternal::verifySchedule() {
221 int64_t numCylesPerIter = opOrder.size();
222 // Pre-compute the unrolled cycle of each op.
223 DenseMap<Operation *, int64_t> unrolledCyles;
224 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
225 Operation *def = opOrder[cycle];
226 auto it = stages.find(Val: def);
227 assert(it != stages.end());
228 int64_t stage = it->second;
229 unrolledCyles[def] = cycle + stage * numCylesPerIter;
230 }
231 for (Operation *consumer : opOrder) {
232 int64_t consumerCycle = unrolledCyles[consumer];
233 for (Value operand : getNestedOperands(op: consumer)) {
234 auto [producer, distance] = getDefiningOpAndDistance(value: operand);
235 if (!producer)
236 continue;
237 auto it = unrolledCyles.find(Val: producer);
238 // Skip producer coming from outside the loop.
239 if (it == unrolledCyles.end())
240 continue;
241 int64_t producerCycle = it->second;
242 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
243 consumer->emitError(message: "operation scheduled before its operands");
244 return false;
245 }
246 }
247 }
248 return true;
249}
250
251/// Clone `op` and call `callback` on the cloned op's oeprands as well as any
252/// operands of nested ops that:
253/// 1) aren't defined within the new op or
254/// 2) are block arguments.
255static Operation *
256cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
257 function_ref<void(OpOperand *newOperand)> callback) {
258 Operation *clone = rewriter.clone(op&: *op);
259 clone->walk<WalkOrder::PreOrder>(callback: [&](Operation *nested) {
260 // 'clone' itself will be visited first.
261 for (OpOperand &operand : nested->getOpOperands()) {
262 Operation *def = operand.get().getDefiningOp();
263 if ((def && !clone->isAncestor(other: def)) || isa<BlockArgument>(Val: operand.get()))
264 callback(&operand);
265 }
266 });
267 return clone;
268}
269
270void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
271 // Initialize the iteration argument to the loop initiale values.
272 for (auto [arg, operand] :
273 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
274 setValueMapping(arg, operand.get(), 0);
275 }
276 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
277 Location loc = forOp.getLoc();
278 SmallVector<Value> predicates(maxStage);
279 for (int64_t i = 0; i < maxStage; i++) {
280 if (dynamicLoop) {
281 Type t = ub.getType();
282 // pred = ub > lb + (i * step)
283 Value iv = rewriter.create<arith::AddIOp>(
284 loc, lb,
285 rewriter.create<arith::MulIOp>(
286 loc, step,
287 rewriter.create<arith::ConstantOp>(
288 loc, rewriter.getIntegerAttr(t, i))));
289 predicates[i] = rewriter.create<arith::CmpIOp>(
290 loc, arith::CmpIPredicate::slt, iv, ub);
291 }
292
293 // special handling for induction variable as the increment is implicit.
294 // iv = lb + i * step
295 Type t = lb.getType();
296 Value iv = rewriter.create<arith::AddIOp>(
297 loc, lb,
298 rewriter.create<arith::MulIOp>(
299 loc, step,
300 rewriter.create<arith::ConstantOp>(loc,
301 rewriter.getIntegerAttr(t, i))));
302 setValueMapping(forOp.getInductionVar(), iv, i);
303 for (Operation *op : opOrder) {
304 if (stages[op] > i)
305 continue;
306 Operation *newOp =
307 cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) {
308 auto it = valueMapping.find(Val: newOperand->get());
309 if (it != valueMapping.end()) {
310 Value replacement = it->second[i - stages[op]];
311 newOperand->set(replacement);
312 }
313 });
314 int predicateIdx = i - stages[op];
315 if (predicates[predicateIdx]) {
316 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
317 assert(newOp && "failed to predicate op.");
318 }
319 rewriter.setInsertionPointAfter(newOp);
320 if (annotateFn)
321 annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
322 for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) {
323 setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId),
324 idx: i - stages[op]);
325 // If the value is a loop carried dependency update the loop argument
326 // mapping.
327 for (OpOperand &operand : yield->getOpOperands()) {
328 if (operand.get() != op->getResult(destId))
329 continue;
330 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
331 newOp->getResult(destId), i - stages[op] + 1);
332 }
333 }
334 }
335 }
336}
337
338llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
339LoopPipelinerInternal::analyzeCrossStageValues() {
340 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
341 for (Operation *op : opOrder) {
342 unsigned stage = stages[op];
343
344 auto analyzeOperand = [&](OpOperand &operand) {
345 auto [def, distance] = getDefiningOpAndDistance(value: operand.get());
346 if (!def)
347 return;
348 auto defStage = stages.find(Val: def);
349 if (defStage == stages.end() || defStage->second == stage ||
350 defStage->second == stage + distance)
351 return;
352 assert(stage > defStage->second);
353 LiverangeInfo &info = crossStageValues[operand.get()];
354 info.defStage = defStage->second;
355 info.lastUseStage = std::max(a: info.lastUseStage, b: stage);
356 };
357
358 for (OpOperand &operand : op->getOpOperands())
359 analyzeOperand(operand);
360 visitUsedValuesDefinedAbove(regions: op->getRegions(), callback: [&](OpOperand *operand) {
361 analyzeOperand(*operand);
362 });
363 }
364 return crossStageValues;
365}
366
367std::pair<Operation *, int64_t>
368LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
369 int64_t distance = 0;
370 if (auto arg = dyn_cast<BlockArgument>(Val&: value)) {
371 if (arg.getOwner() != forOp.getBody())
372 return {nullptr, 0};
373 // Ignore induction variable.
374 if (arg.getArgNumber() == 0)
375 return {nullptr, 0};
376 distance++;
377 value =
378 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
379 }
380 Operation *def = value.getDefiningOp();
381 if (!def)
382 return {nullptr, 0};
383 return {def, distance};
384}
385
386scf::ForOp LoopPipelinerInternal::createKernelLoop(
387 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
388 &crossStageValues,
389 RewriterBase &rewriter,
390 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
391 // Creates the list of initial values associated to values used across
392 // stages. The initial values come from the prologue created above.
393 // Keep track of the kernel argument associated to each version of the
394 // values passed to the kernel.
395 llvm::SmallVector<Value> newLoopArg;
396 // For existing loop argument initialize them with the right version from the
397 // prologue.
398 for (const auto &retVal :
399 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
400 Operation *def = retVal.value().getDefiningOp();
401 assert(def && "Only support loop carried dependencies of distance of 1 or "
402 "outside the loop");
403 auto defStage = stages.find(def);
404 if (defStage != stages.end()) {
405 Value valueVersion =
406 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
407 [maxStage - defStage->second];
408 assert(valueVersion);
409 newLoopArg.push_back(valueVersion);
410 } else
411 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
412 }
413 for (auto escape : crossStageValues) {
414 LiverangeInfo &info = escape.second;
415 Value value = escape.first;
416 for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
417 stageIdx++) {
418 Value valueVersion =
419 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
420 assert(valueVersion);
421 newLoopArg.push_back(Elt: valueVersion);
422 loopArgMap[std::make_pair(x&: value, y: info.lastUseStage - info.defStage -
423 stageIdx)] = newLoopArg.size() - 1;
424 }
425 }
426
427 // Create the new kernel loop. When we peel the epilgue we need to peel
428 // `numStages - 1` iterations. Then we adjust the upper bound to remove those
429 // iterations.
430 Value newUb = forOp.getUpperBound();
431 if (peelEpilogue) {
432 Type t = ub.getType();
433 Location loc = forOp.getLoc();
434 // newUb = ub - maxStage * step
435 Value maxStageValue = rewriter.create<arith::ConstantOp>(
436 loc, rewriter.getIntegerAttr(t, maxStage));
437 Value maxStageByStep =
438 rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
439 newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
440 }
441 auto newForOp =
442 rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
443 forOp.getStep(), newLoopArg);
444 // When there are no iter args, the loop body terminator will be created.
445 // Since we always create it below, remove the terminator if it was created.
446 if (!newForOp.getBody()->empty())
447 rewriter.eraseOp(op: newForOp.getBody()->getTerminator());
448 return newForOp;
449}
450
451LogicalResult LoopPipelinerInternal::createKernel(
452 scf::ForOp newForOp,
453 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
454 &crossStageValues,
455 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
456 RewriterBase &rewriter) {
457 valueMapping.clear();
458
459 // Create the kernel, we clone instruction based on the order given by
460 // user and remap operands coming from a previous stages.
461 rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
462 IRMapping mapping;
463 mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
464 for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
465 mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
466 }
467 SmallVector<Value> predicates(maxStage + 1, nullptr);
468 if (!peelEpilogue) {
469 // Create a predicate for each stage except the last stage.
470 Location loc = newForOp.getLoc();
471 Type t = ub.getType();
472 for (unsigned i = 0; i < maxStage; i++) {
473 // c = ub - (maxStage - i) * step
474 Value c = rewriter.create<arith::SubIOp>(
475 loc, ub,
476 rewriter.create<arith::MulIOp>(
477 loc, step,
478 rewriter.create<arith::ConstantOp>(
479 loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
480
481 Value pred = rewriter.create<arith::CmpIOp>(
482 newForOp.getLoc(), arith::CmpIPredicate::slt,
483 newForOp.getInductionVar(), c);
484 predicates[i] = pred;
485 }
486 }
487 for (Operation *op : opOrder) {
488 int64_t useStage = stages[op];
489 auto *newOp = rewriter.clone(op&: *op, mapper&: mapping);
490 SmallVector<OpOperand *> operands;
491 // Collect all the operands for the cloned op and its nested ops.
492 op->walk(callback: [&operands](Operation *nestedOp) {
493 for (OpOperand &operand : nestedOp->getOpOperands()) {
494 operands.push_back(Elt: &operand);
495 }
496 });
497 for (OpOperand *operand : operands) {
498 Operation *nestedNewOp = mapping.lookup(from: operand->getOwner());
499 // Special case for the induction variable uses. We replace it with a
500 // version incremented based on the stage where it is used.
501 if (operand->get() == forOp.getInductionVar()) {
502 rewriter.setInsertionPoint(newOp);
503
504 // offset = (maxStage - stages[op]) * step
505 Type t = step.getType();
506 Value offset = rewriter.create<arith::MulIOp>(
507 forOp.getLoc(), step,
508 rewriter.create<arith::ConstantOp>(
509 forOp.getLoc(),
510 rewriter.getIntegerAttr(t, maxStage - stages[op])));
511 Value iv = rewriter.create<arith::AddIOp>(
512 forOp.getLoc(), newForOp.getInductionVar(), offset);
513 nestedNewOp->setOperand(idx: operand->getOperandNumber(), value: iv);
514 rewriter.setInsertionPointAfter(newOp);
515 continue;
516 }
517 Value source = operand->get();
518 auto arg = dyn_cast<BlockArgument>(Val&: source);
519 if (arg && arg.getOwner() == forOp.getBody()) {
520 Value ret = forOp.getBody()->getTerminator()->getOperand(
521 arg.getArgNumber() - 1);
522 Operation *dep = ret.getDefiningOp();
523 if (!dep)
524 continue;
525 auto stageDep = stages.find(Val: dep);
526 if (stageDep == stages.end() || stageDep->second == useStage)
527 continue;
528 // If the value is a loop carried value coming from stage N + 1 remap,
529 // it will become a direct use.
530 if (stageDep->second == useStage + 1) {
531 nestedNewOp->setOperand(idx: operand->getOperandNumber(),
532 value: mapping.lookupOrDefault(from: ret));
533 continue;
534 }
535 source = ret;
536 }
537 // For operands defined in a previous stage we need to remap it to use
538 // the correct region argument. We look for the right version of the
539 // Value based on the stage where it is used.
540 Operation *def = source.getDefiningOp();
541 if (!def)
542 continue;
543 auto stageDef = stages.find(Val: def);
544 if (stageDef == stages.end() || stageDef->second == useStage)
545 continue;
546 auto remap = loopArgMap.find(
547 Val: std::make_pair(x: operand->get(), y: useStage - stageDef->second));
548 assert(remap != loopArgMap.end());
549 nestedNewOp->setOperand(idx: operand->getOperandNumber(),
550 value: newForOp.getRegionIterArgs()[remap->second]);
551 }
552
553 if (predicates[useStage]) {
554 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
555 if (!newOp)
556 return failure();
557 // Remap the results to the new predicated one.
558 for (auto values : llvm::zip(t: op->getResults(), u: newOp->getResults()))
559 mapping.map(from: std::get<0>(t&: values), to: std::get<1>(t&: values));
560 }
561 rewriter.setInsertionPointAfter(newOp);
562 if (annotateFn)
563 annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
564 }
565
566 // Collect the Values that need to be returned by the forOp. For each
567 // value we need to have `LastUseStage - DefStage` number of versions
568 // returned.
569 // We create a mapping between original values and the associated loop
570 // returned values that will be needed by the epilogue.
571 llvm::SmallVector<Value> yieldOperands;
572 for (OpOperand &yieldOperand :
573 forOp.getBody()->getTerminator()->getOpOperands()) {
574 Value source = mapping.lookupOrDefault(yieldOperand.get());
575 // When we don't peel the epilogue and the yield value is used outside the
576 // loop we need to make sure we return the version from numStages -
577 // defStage.
578 if (!peelEpilogue &&
579 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
580 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
581 if (def) {
582 auto defStage = stages.find(def);
583 if (defStage != stages.end() && defStage->second < maxStage) {
584 Value pred = predicates[defStage->second];
585 source = rewriter.create<arith::SelectOp>(
586 pred.getLoc(), pred, source,
587 newForOp.getBody()
588 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
589 }
590 }
591 }
592 yieldOperands.push_back(source);
593 }
594
595 for (auto &it : crossStageValues) {
596 int64_t version = maxStage - it.second.lastUseStage + 1;
597 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
598 // add the original version to yield ops.
599 // If there is a live range spanning across more than 2 stages we need to
600 // add extra arg.
601 for (unsigned i = 1; i < numVersionReturned; i++) {
602 setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()),
603 idx: version++);
604 yieldOperands.push_back(
605 Elt: newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
606 newForOp.getNumInductionVars()]);
607 }
608 setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()),
609 idx: version++);
610 yieldOperands.push_back(Elt: mapping.lookupOrDefault(from: it.first));
611 }
612 // Map the yield operand to the forOp returned value.
613 for (const auto &retVal :
614 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
615 Operation *def = retVal.value().getDefiningOp();
616 assert(def && "Only support loop carried dependencies of distance of 1 or "
617 "defined outside the loop");
618 auto defStage = stages.find(def);
619 if (defStage == stages.end()) {
620 for (unsigned int stage = 1; stage <= maxStage; stage++)
621 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
622 retVal.value(), stage);
623 } else if (defStage->second > 0) {
624 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
625 newForOp->getResult(retVal.index()),
626 maxStage - defStage->second + 1);
627 }
628 }
629 rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
630 return success();
631}
632
633void LoopPipelinerInternal::emitEpilogue(
634 RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
635 // Emit different versions of the induction variable. They will be
636 // removed by dead code if not used.
637 for (int64_t i = 0; i < maxStage; i++) {
638 Location loc = forOp.getLoc();
639 Type t = lb.getType();
640 Value minusOne =
641 rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
642 // number of iterations = ((ub - 1) - lb) / step
643 Value totalNumIteration = rewriter.create<arith::DivUIOp>(
644 loc,
645 rewriter.create<arith::SubIOp>(
646 loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
647 step);
648 // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
649 Value minusI =
650 rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
651 Value newlastIter = rewriter.create<arith::AddIOp>(
652 loc, lb,
653 rewriter.create<arith::MulIOp>(
654 loc, step,
655 rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
656 setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
657 }
658 // Emit `maxStage - 1` epilogue part that includes operations from stages
659 // [i; maxStage].
660 for (int64_t i = 1; i <= maxStage; i++) {
661 for (Operation *op : opOrder) {
662 if (stages[op] < i)
663 continue;
664 Operation *newOp =
665 cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) {
666 auto it = valueMapping.find(Val: newOperand->get());
667 if (it != valueMapping.end()) {
668 Value replacement = it->second[maxStage - stages[op] + i];
669 newOperand->set(replacement);
670 }
671 });
672 if (annotateFn)
673 annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
674 for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) {
675 setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId),
676 idx: maxStage - stages[op] + i);
677 // If the value is a loop carried dependency update the loop argument
678 // mapping and keep track of the last version to replace the original
679 // forOp uses.
680 for (OpOperand &operand :
681 forOp.getBody()->getTerminator()->getOpOperands()) {
682 if (operand.get() != op->getResult(destId))
683 continue;
684 unsigned version = maxStage - stages[op] + i + 1;
685 // If the version is greater than maxStage it means it maps to the
686 // original forOp returned value.
687 if (version > maxStage) {
688 returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
689 continue;
690 }
691 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
692 newOp->getResult(destId), version);
693 }
694 }
695 }
696 }
697}
698
699void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
700 auto it = valueMapping.find(Val: key);
701 // If the value is not in the map yet add a vector big enough to store all
702 // versions.
703 if (it == valueMapping.end())
704 it =
705 valueMapping
706 .insert(KV: std::make_pair(x&: key, y: llvm::SmallVector<Value>(maxStage + 1)))
707 .first;
708 it->second[idx] = el;
709}
710
711} // namespace
712
713FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
714 const PipeliningOption &options,
715 bool *modifiedIR) {
716 if (modifiedIR)
717 *modifiedIR = false;
718 LoopPipelinerInternal pipeliner;
719 if (!pipeliner.initializeLoopInfo(op: forOp, options))
720 return failure();
721
722 if (modifiedIR)
723 *modifiedIR = true;
724
725 // 1. Emit prologue.
726 pipeliner.emitPrologue(rewriter);
727
728 // 2. Track values used across stages. When a value cross stages it will
729 // need to be passed as loop iteration arguments.
730 // We first collect the values that are used in a different stage than where
731 // they are defined.
732 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
733 crossStageValues = pipeliner.analyzeCrossStageValues();
734
735 // Mapping between original loop values used cross stage and the block
736 // arguments associated after pipelining. A Value may map to several
737 // arguments if its liverange spans across more than 2 stages.
738 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
739 // 3. Create the new kernel loop and return the block arguments mapping.
740 ForOp newForOp =
741 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
742 // Create the kernel block, order ops based on user choice and remap
743 // operands.
744 if (failed(pipeliner.createKernel(newForOp: newForOp, crossStageValues, loopArgMap,
745 rewriter)))
746 return failure();
747
748 llvm::SmallVector<Value> returnValues =
749 newForOp.getResults().take_front(forOp->getNumResults());
750 if (options.peelEpilogue) {
751 // 4. Emit the epilogue after the new forOp.
752 rewriter.setInsertionPointAfter(newForOp);
753 pipeliner.emitEpilogue(rewriter, returnValues);
754 }
755 // 5. Erase the original loop and replace the uses with the epilogue output.
756 if (forOp->getNumResults() > 0)
757 rewriter.replaceOp(forOp, returnValues);
758 else
759 rewriter.eraseOp(op: forOp);
760
761 return newForOp;
762}
763
764void mlir::scf::populateSCFLoopPipeliningPatterns(
765 RewritePatternSet &patterns, const PipeliningOption &options) {
766 patterns.add<ForLoopPipeliningPattern>(arg: options, args: patterns.getContext());
767}
768

// Source: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp