1//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop software pipelining
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SCF/Transforms/Patterns.h"
16#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17#include "mlir/Dialect/SCF/Utils/Utils.h"
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Transforms/RegionUtils.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/MathExtras.h"
24
25#define DEBUG_TYPE "scf-loop-pipelining"
26#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
28
29using namespace mlir;
30using namespace mlir::scf;
31
32namespace {
33
34/// Helper to keep internal information during pipelining transformation.
35struct LoopPipelinerInternal {
36 /// Coarse liverange information for ops used across stages.
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
40 };
41
42protected:
43 ForOp forOp;
44 unsigned maxStage = 0;
45 DenseMap<Operation *, unsigned> stages;
46 std::vector<Operation *> opOrder;
47 Value ub;
48 Value lb;
49 Value step;
50 bool dynamicLoop;
51 PipeliningOption::AnnotationlFnType annotateFn = nullptr;
52 bool peelEpilogue;
53 PipeliningOption::PredicateOpFn predicateFn = nullptr;
54
55 // When peeling the kernel we generate several version of each value for
56 // different stage of the prologue. This map tracks the mapping between
57 // original Values in the loop and the different versions
58 // peeled from the loop.
59 DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
60
61 /// Assign a value to `valueMapping`, this means `val` represents the version
62 /// `idx` of `key` in the epilogue.
63 void setValueMapping(Value key, Value el, int64_t idx);
64
65 /// Return the defining op of the given value, if the Value is an argument of
66 /// the loop return the associated defining op in the loop and its distance to
67 /// the Value.
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
69
70 /// Return true if the schedule is possible and return false otherwise. A
71 /// schedule is correct if all definitions are scheduled before uses.
72 bool verifySchedule();
73
74public:
75 /// Initalize the information for the given `op`, return true if it
76 /// satisfies the pre-condition to apply pipelining.
77 bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
78 /// Emits the prologue, this creates `maxStage - 1` part which will contain
79 /// operations from stages [0; i], where i is the part index.
80 LogicalResult emitPrologue(RewriterBase &rewriter);
81 /// Gather liverange information for Values that are used in a different stage
82 /// than its definition.
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
86 RewriterBase &rewriter,
87 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
88 /// Emits the pipelined kernel. This clones loop operations following user
89 /// order and remaps operands defined in a different stage as their use.
90 LogicalResult createKernel(
91 scf::ForOp newForOp,
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
94 RewriterBase &rewriter);
95 /// Emits the epilogue, this creates `maxStage - 1` part which will contain
96 /// operations from stages [i; maxStage], where i is the part index.
97 LogicalResult emitEpilogue(RewriterBase &rewriter,
98 llvm::SmallVector<Value> &returnValues);
99};
100
101bool LoopPipelinerInternal::initializeLoopInfo(
102 ForOp op, const PipeliningOption &options) {
103 LDBG("Start initializeLoopInfo");
104 forOp = op;
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
108
109 dynamicLoop = true;
110 auto upperBoundCst = getConstantIntValue(ofr: ub);
111 auto lowerBoundCst = getConstantIntValue(ofr: lb);
112 auto stepCst = getConstantIntValue(ofr: step);
113 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114 if (!options.supportDynamicLoops) {
115 LDBG("--dynamic loop not supported -> BAIL");
116 return false;
117 }
118 } else {
119 int64_t ubImm = upperBoundCst.value();
120 int64_t lbImm = lowerBoundCst.value();
121 int64_t stepImm = stepCst.value();
122 if (stepImm <= 0) {
123 LDBG("--invalid loop step -> BAIL");
124 return false;
125 }
126 int64_t numIteration = llvm::divideCeilSigned(Numerator: ubImm - lbImm, Denominator: stepImm);
127 if (numIteration > maxStage) {
128 dynamicLoop = false;
129 } else if (!options.supportDynamicLoops) {
130 LDBG("--fewer loop iterations than pipeline stages -> BAIL");
131 return false;
132 }
133 }
134 peelEpilogue = options.peelEpilogue;
135 predicateFn = options.predicateFn;
136 if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
137 LDBG("--no epilogue or predicate set -> BAIL");
138 return false;
139 }
140 std::vector<std::pair<Operation *, unsigned>> schedule;
141 options.getScheduleFn(forOp, schedule);
142 if (schedule.empty()) {
143 LDBG("--empty schedule -> BAIL");
144 return false;
145 }
146
147 opOrder.reserve(n: schedule.size());
148 for (auto &opSchedule : schedule) {
149 maxStage = std::max(a: maxStage, b: opSchedule.second);
150 stages[opSchedule.first] = opSchedule.second;
151 opOrder.push_back(x: opSchedule.first);
152 }
153
154 // All operations need to have a stage.
155 for (Operation &op : forOp.getBody()->without_terminator()) {
156 if (!stages.contains(&op)) {
157 op.emitOpError("not assigned a pipeline stage");
158 LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
159 return false;
160 }
161 }
162
163 if (!verifySchedule()) {
164 LDBG("--invalid schedule: " << op << " -> BAIL");
165 return false;
166 }
167
168 // Currently, we do not support assigning stages to ops in nested regions. The
169 // block of all operations assigned a stage should be the single `scf.for`
170 // body block.
171 for (const auto &[op, stageNum] : stages) {
172 (void)stageNum;
173 if (op == forOp.getBody()->getTerminator()) {
174 op->emitError(message: "terminator should not be assigned a stage");
175 LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
176 return false;
177 }
178 if (op->getBlock() != forOp.getBody()) {
179 op->emitOpError(message: "the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG("--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op << " -> BAIL");
184 return false;
185 }
186 }
187
188 // Support only loop-carried dependencies with a distance of one iteration or
189 // those defined outside of the loop. This means that any dependency within a
190 // loop should either be on the immediately preceding iteration, the current
191 // iteration, or on variables whose values are set before entering the loop.
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [this](Value operand) {
194 Operation *def = operand.getDefiningOp();
195 return !def ||
196 (!stages.contains(def) && forOp->isAncestor(def));
197 })) {
198 LDBG("--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL");
200 return false;
201 }
202 annotateFn = options.annotateFn;
203 return true;
204}
205
206/// Find operands of all the nested operations within `op`.
207static SetVector<Value> getNestedOperands(Operation *op) {
208 SetVector<Value> operands;
209 op->walk(callback: [&](Operation *nestedOp) {
210 operands.insert_range(R: nestedOp->getOperands());
211 });
212 return operands;
213}
214
215/// Compute unrolled cycles of each op (consumer) and verify that each op is
216/// scheduled after its operands (producers) while adjusting for the distance
217/// between producer and consumer.
218bool LoopPipelinerInternal::verifySchedule() {
219 int64_t numCylesPerIter = opOrder.size();
220 // Pre-compute the unrolled cycle of each op.
221 DenseMap<Operation *, int64_t> unrolledCyles;
222 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
223 Operation *def = opOrder[cycle];
224 auto it = stages.find(Val: def);
225 assert(it != stages.end());
226 int64_t stage = it->second;
227 unrolledCyles[def] = cycle + stage * numCylesPerIter;
228 }
229 for (Operation *consumer : opOrder) {
230 int64_t consumerCycle = unrolledCyles[consumer];
231 for (Value operand : getNestedOperands(op: consumer)) {
232 auto [producer, distance] = getDefiningOpAndDistance(value: operand);
233 if (!producer)
234 continue;
235 auto it = unrolledCyles.find(Val: producer);
236 // Skip producer coming from outside the loop.
237 if (it == unrolledCyles.end())
238 continue;
239 int64_t producerCycle = it->second;
240 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
241 consumer->emitError(message: "operation scheduled before its operands");
242 return false;
243 }
244 }
245 }
246 return true;
247}
248
249/// Clone `op` and call `callback` on the cloned op's oeprands as well as any
250/// operands of nested ops that:
251/// 1) aren't defined within the new op or
252/// 2) are block arguments.
253static Operation *
254cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
255 function_ref<void(OpOperand *newOperand)> callback) {
256 Operation *clone = rewriter.clone(op&: *op);
257 clone->walk<WalkOrder::PreOrder>(callback: [&](Operation *nested) {
258 // 'clone' itself will be visited first.
259 for (OpOperand &operand : nested->getOpOperands()) {
260 Operation *def = operand.get().getDefiningOp();
261 if ((def && !clone->isAncestor(other: def)) || isa<BlockArgument>(Val: operand.get()))
262 callback(&operand);
263 }
264 });
265 return clone;
266}
267
/// Emit the prologue of the pipelined loop.
///
/// For every part `i` in [0, maxStage) this clones, following `opOrder`, each
/// op whose stage is <= `i`, replacing operands that have an entry in
/// `valueMapping` by version `i - stage`. Results of the clones are recorded
/// back into `valueMapping` under version `i - stage`; iteration arguments fed
/// by those results through the yield are recorded under `i - stage + 1`.
///
/// When `dynamicLoop` is set, each clone is passed to `predicateFn` together
/// with the predicate `lb + (i - stage) * step < ub`.
///
/// Returns failure if `predicateFn` returns a null operation.
268LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
269 // Initialize the iteration argument to the loop initial values.
270 for (auto [arg, operand] :
271 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
272 setValueMapping(arg, operand.get(), 0);
273 }
274 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
275 Location loc = forOp.getLoc();
 // One predicate slot per peeled iteration; slots stay null for static loops.
276 SmallVector<Value> predicates(maxStage);
277 for (int64_t i = 0; i < maxStage; i++) {
278 if (dynamicLoop) {
279 Type t = ub.getType();
280 // pred = ub > lb + (i * step)
281 Value iv = rewriter.create<arith::AddIOp>(
282 loc, lb,
283 rewriter.create<arith::MulIOp>(
284 loc, step,
285 rewriter.create<arith::ConstantOp>(
286 loc, rewriter.getIntegerAttr(t, i))));
287 predicates[i] = rewriter.create<arith::CmpIOp>(
288 loc, arith::CmpIPredicate::slt, iv, ub);
289 }
290
291 // special handling for induction variable as the increment is implicit.
292 // iv = lb + i * step
293 Type t = lb.getType();
294 Value iv = rewriter.create<arith::AddIOp>(
295 loc, lb,
296 rewriter.create<arith::MulIOp>(
297 loc, step,
298 rewriter.create<arith::ConstantOp>(loc,
299 rewriter.getIntegerAttr(t, i))));
300 setValueMapping(forOp.getInductionVar(), iv, i);
301 for (Operation *op : opOrder) {
 // Ops of later stages are not part of prologue part `i`.
302 if (stages[op] > i)
303 continue;
304 Operation *newOp =
305 cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) {
306 auto it = valueMapping.find(Val: newOperand->get());
307 if (it != valueMapping.end()) {
308 Value replacement = it->second[i - stages[op]];
309 newOperand->set(replacement);
310 }
311 });
 // `i - stages[op]` selects both the value version and the predicate used
 // for this clone.
312 int predicateIdx = i - stages[op];
313 if (predicates[predicateIdx]) {
314 OpBuilder::InsertionGuard insertGuard(rewriter);
315 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
316 if (newOp == nullptr)
317 return failure();
318 }
319 if (annotateFn)
320 annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
321 for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) {
322 Value source = newOp->getResult(idx: destId);
323 // If the value is a loop carried dependency update the loop argument
324 for (OpOperand &operand : yield->getOpOperands()) {
325 if (operand.get() != op->getResult(destId))
326 continue;
327 if (predicates[predicateIdx] &&
328 !forOp.getResult(operand.getOperandNumber()).use_empty()) {
329 // If the value is used outside the loop, we need to make sure we
330 // return the correct version of it.
331 Value prevValue = valueMapping
332 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
333 [i - stages[op]];
334 source = rewriter.create<arith::SelectOp>(
335 loc, predicates[predicateIdx], source, prevValue);
336 }
337 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
338 source, i - stages[op] + 1);
339 }
340 setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId),
341 idx: i - stages[op]);
342 }
343 }
344 }
345 return success();
346}
347
348llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
349LoopPipelinerInternal::analyzeCrossStageValues() {
350 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
351 for (Operation *op : opOrder) {
352 unsigned stage = stages[op];
353
354 auto analyzeOperand = [&](OpOperand &operand) {
355 auto [def, distance] = getDefiningOpAndDistance(value: operand.get());
356 if (!def)
357 return;
358 auto defStage = stages.find(Val: def);
359 if (defStage == stages.end() || defStage->second == stage ||
360 defStage->second == stage + distance)
361 return;
362 assert(stage > defStage->second);
363 LiverangeInfo &info = crossStageValues[operand.get()];
364 info.defStage = defStage->second;
365 info.lastUseStage = std::max(a: info.lastUseStage, b: stage);
366 };
367
368 for (OpOperand &operand : op->getOpOperands())
369 analyzeOperand(operand);
370 visitUsedValuesDefinedAbove(regions: op->getRegions(), callback: [&](OpOperand *operand) {
371 analyzeOperand(*operand);
372 });
373 }
374 return crossStageValues;
375}
376
377std::pair<Operation *, int64_t>
378LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
379 int64_t distance = 0;
380 if (auto arg = dyn_cast<BlockArgument>(Val&: value)) {
381 if (arg.getOwner() != forOp.getBody())
382 return {nullptr, 0};
383 // Ignore induction variable.
384 if (arg.getArgNumber() == 0)
385 return {nullptr, 0};
386 distance++;
387 value =
388 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
389 }
390 Operation *def = value.getDefiningOp();
391 if (!def)
392 return {nullptr, 0};
393 return {def, distance};
394}
395
396scf::ForOp LoopPipelinerInternal::createKernelLoop(
397 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
398 &crossStageValues,
399 RewriterBase &rewriter,
400 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
401 // Creates the list of initial values associated to values used across
402 // stages. The initial values come from the prologue created above.
403 // Keep track of the kernel argument associated to each version of the
404 // values passed to the kernel.
405 llvm::SmallVector<Value> newLoopArg;
406 // For existing loop argument initialize them with the right version from the
407 // prologue.
408 for (const auto &retVal :
409 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
410 Operation *def = retVal.value().getDefiningOp();
411 assert(def && "Only support loop carried dependencies of distance of 1 or "
412 "outside the loop");
413 auto defStage = stages.find(def);
414 if (defStage != stages.end()) {
415 Value valueVersion =
416 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
417 [maxStage - defStage->second];
418 assert(valueVersion);
419 newLoopArg.push_back(valueVersion);
420 } else {
421 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
422 }
423 }
424 for (auto escape : crossStageValues) {
425 LiverangeInfo &info = escape.second;
426 Value value = escape.first;
427 for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
428 stageIdx++) {
429 Value valueVersion =
430 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
431 assert(valueVersion);
432 newLoopArg.push_back(Elt: valueVersion);
433 loopArgMap[std::make_pair(x&: value, y: info.lastUseStage - info.defStage -
434 stageIdx)] = newLoopArg.size() - 1;
435 }
436 }
437
438 // Create the new kernel loop. When we peel the epilgue we need to peel
439 // `numStages - 1` iterations. Then we adjust the upper bound to remove those
440 // iterations.
441 Value newUb = forOp.getUpperBound();
442 if (peelEpilogue) {
443 Type t = ub.getType();
444 Location loc = forOp.getLoc();
445 // newUb = ub - maxStage * step
446 Value maxStageValue = rewriter.create<arith::ConstantOp>(
447 loc, rewriter.getIntegerAttr(t, maxStage));
448 Value maxStageByStep =
449 rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
450 newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
451 }
452 auto newForOp =
453 rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
454 forOp.getStep(), newLoopArg);
455 // When there are no iter args, the loop body terminator will be created.
456 // Since we always create it below, remove the terminator if it was created.
457 if (!newForOp.getBody()->empty())
458 rewriter.eraseOp(op: newForOp.getBody()->getTerminator());
459 return newForOp;
460}
461
/// Populate the body of `newForOp` with the pipelined kernel.
///
/// `valueMapping` is cleared on entry. Every op of `opOrder` is then cloned
/// into the new body and its operands (including those of nested ops) are
/// rewired:
///  - uses of the original induction variable become
///    `newIV + (maxStage - stage) * step`;
///  - uses of an iteration argument whose yielded producer is exactly one
///    stage later become direct uses of that producer's clone;
///  - other uses of a value defined in a different stage are redirected to the
///    kernel iteration argument recorded in `loopArgMap`.
/// When the epilogue is not peeled, ops of stages below `maxStage` are passed
/// to `predicateFn` with the predicate `newIV < ub - (maxStage - stage) * step`.
///
/// Finally the new `scf.yield` is created and `valueMapping` is filled with
/// the kernel results holding each value version.
///
/// Returns failure if `predicateFn` returns a null operation.
462LogicalResult LoopPipelinerInternal::createKernel(
463 scf::ForOp newForOp,
464 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
465 &crossStageValues,
466 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
467 RewriterBase &rewriter) {
468 valueMapping.clear();
469
470 // Create the kernel, we clone instruction based on the order given by
471 // user and remap operands coming from a previous stages.
472 rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
473 IRMapping mapping;
474 mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
475 for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
476 mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
477 }
 // One slot per stage; slots stay null when the epilogue is peeled, and the
 // slot of the last stage is always null.
478 SmallVector<Value> predicates(maxStage + 1, nullptr);
479 if (!peelEpilogue) {
480 // Create a predicate for each stage except the last stage.
481 Location loc = newForOp.getLoc();
482 Type t = ub.getType();
483 for (unsigned i = 0; i < maxStage; i++) {
484 // c = ub - (maxStage - i) * step
485 Value c = rewriter.create<arith::SubIOp>(
486 loc, ub,
487 rewriter.create<arith::MulIOp>(
488 loc, step,
489 rewriter.create<arith::ConstantOp>(
490 loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
491
492 Value pred = rewriter.create<arith::CmpIOp>(
493 newForOp.getLoc(), arith::CmpIPredicate::slt,
494 newForOp.getInductionVar(), c);
495 predicates[i] = pred;
496 }
497 }
498 for (Operation *op : opOrder) {
499 int64_t useStage = stages[op];
500 auto *newOp = rewriter.clone(op&: *op, mapper&: mapping);
501 SmallVector<OpOperand *> operands;
502 // Collect all the operands for the cloned op and its nested ops.
 // The walk runs over the original op; `mapping` is used below to find the
 // matching cloned owner of each operand.
503 op->walk(callback: [&operands](Operation *nestedOp) {
504 for (OpOperand &operand : nestedOp->getOpOperands()) {
505 operands.push_back(Elt: &operand);
506 }
507 });
508 for (OpOperand *operand : operands) {
509 Operation *nestedNewOp = mapping.lookup(from: operand->getOwner());
510 // Special case for the induction variable uses. We replace it with a
511 // version incremented based on the stage where it is used.
512 if (operand->get() == forOp.getInductionVar()) {
513 rewriter.setInsertionPoint(newOp);
514
515 // offset = (maxStage - stages[op]) * step
516 Type t = step.getType();
517 Value offset = rewriter.create<arith::MulIOp>(
518 forOp.getLoc(), step,
519 rewriter.create<arith::ConstantOp>(
520 forOp.getLoc(),
521 rewriter.getIntegerAttr(t, maxStage - stages[op])));
522 Value iv = rewriter.create<arith::AddIOp>(
523 forOp.getLoc(), newForOp.getInductionVar(), offset);
524 nestedNewOp->setOperand(idx: operand->getOperandNumber(), value: iv);
525 rewriter.setInsertionPointAfter(newOp);
526 continue;
527 }
528 Value source = operand->get();
529 auto arg = dyn_cast<BlockArgument>(Val&: source);
530 if (arg && arg.getOwner() == forOp.getBody()) {
 // Iteration argument: look at the value yielded for it.
531 Value ret = forOp.getBody()->getTerminator()->getOperand(
532 arg.getArgNumber() - 1);
533 Operation *dep = ret.getDefiningOp();
534 if (!dep)
535 continue;
536 auto stageDep = stages.find(Val: dep);
537 if (stageDep == stages.end() || stageDep->second == useStage)
538 continue;
539 // If the value is a loop carried value coming from stage N + 1 remap,
540 // it will become a direct use.
541 if (stageDep->second == useStage + 1) {
542 nestedNewOp->setOperand(idx: operand->getOperandNumber(),
543 value: mapping.lookupOrDefault(from: ret));
544 continue;
545 }
546 source = ret;
547 }
548 // For operands defined in a previous stage we need to remap it to use
549 // the correct region argument. We look for the right version of the
550 // Value based on the stage where it is used.
551 Operation *def = source.getDefiningOp();
552 if (!def)
553 continue;
554 auto stageDef = stages.find(Val: def);
555 if (stageDef == stages.end() || stageDef->second == useStage)
556 continue;
557 auto remap = loopArgMap.find(
558 Val: std::make_pair(x: operand->get(), y: useStage - stageDef->second));
559 assert(remap != loopArgMap.end());
560 nestedNewOp->setOperand(idx: operand->getOperandNumber(),
561 value: newForOp.getRegionIterArgs()[remap->second]);
562 }
563
564 if (predicates[useStage]) {
565 OpBuilder::InsertionGuard insertGuard(rewriter);
566 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
567 if (!newOp)
568 return failure();
569 // Remap the results to the new predicated one.
570 for (auto values : llvm::zip(t: op->getResults(), u: newOp->getResults()))
571 mapping.map(from: std::get<0>(t&: values), to: std::get<1>(t&: values));
572 }
573 if (annotateFn)
574 annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
575 }
576
577 // Collect the Values that need to be returned by the forOp. For each
578 // value we need to have `LastUseStage - DefStage` number of versions
579 // returned.
580 // We create a mapping between original values and the associated loop
581 // returned values that will be needed by the epilogue.
582 llvm::SmallVector<Value> yieldOperands;
583 for (OpOperand &yieldOperand :
584 forOp.getBody()->getTerminator()->getOpOperands()) {
585 Value source = mapping.lookupOrDefault(yieldOperand.get());
586 // When we don't peel the epilogue and the yield value is used outside the
587 // loop we need to make sure we return the version from numStages -
588 // defStage.
589 if (!peelEpilogue &&
590 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
591 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
592 if (def) {
593 auto defStage = stages.find(def);
594 if (defStage != stages.end() && defStage->second < maxStage) {
595 Value pred = predicates[defStage->second];
 // Body argument 0 is the induction variable, hence the `+ 1`.
596 source = rewriter.create<arith::SelectOp>(
597 pred.getLoc(), pred, source,
598 newForOp.getBody()
599 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
600 }
601 }
602 }
603 yieldOperands.push_back(source);
604 }
605
606 for (auto &it : crossStageValues) {
607 int64_t version = maxStage - it.second.lastUseStage + 1;
608 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
609 // add the original version to yield ops.
610 // If there is a live range spanning across more than 2 stages we need to
611 // add extra arg.
612 for (unsigned i = 1; i < numVersionReturned; i++) {
613 setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()),
614 idx: version++);
 // Yield the body argument of the next iteration argument, so that the
 // versions shift by one slot on every iteration.
615 yieldOperands.push_back(
616 Elt: newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
617 newForOp.getNumInductionVars()]);
618 }
619 setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()),
620 idx: version++);
621 yieldOperands.push_back(Elt: mapping.lookupOrDefault(from: it.first));
622 }
623 // Map the yield operand to the forOp returned value.
624 for (const auto &retVal :
625 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
626 Operation *def = retVal.value().getDefiningOp();
627 assert(def && "Only support loop carried dependencies of distance of 1 or "
628 "defined outside the loop");
629 auto defStage = stages.find(def);
630 if (defStage == stages.end()) {
 // The yielded value has no stage: every version is the value itself.
631 for (unsigned int stage = 1; stage <= maxStage; stage++)
632 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
633 retVal.value(), stage);
634 } else if (defStage->second > 0) {
635 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
636 newForOp->getResult(retVal.index()),
637 maxStage - defStage->second + 1);
638 }
639 }
640 rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
641 return success();
642}
643
/// Emit the peeled epilogue.
///
/// First computes the total iteration count of the original loop and one
/// induction-variable value per epilogue part, starting at iteration
/// `max(0, total_iters - maxStage)`. Then, for each part `i` in [1, maxStage],
/// clones in `opOrder` order every op whose stage is >= `i`, replacing
/// operands that have an entry in `valueMapping` by version
/// `maxStage - stage + i`.
///
/// When `dynamicLoop` is set, each clone is passed to `predicateFn` with the
/// predicate `total_iters >= version`, and live-out values are selected
/// between the peeled value and the previously recorded version.
///
/// `returnValues` is updated in place: each entry whose yield operand is
/// produced by an epilogue op is replaced by the value produced for it.
///
/// Returns failure if `predicateFn` returns a null operation.
644LogicalResult
645LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
646 llvm::SmallVector<Value> &returnValues) {
647 Location loc = forOp.getLoc();
648 Type t = lb.getType();
649
650 // Emit different versions of the induction variable. They will be
651 // removed by dead code if not used.
652
 // Builds an integer constant of the lower bound's type.
653 auto createConst = [&](int v) {
654 return rewriter.create<arith::ConstantOp>(loc,
655 rewriter.getIntegerAttr(t, v));
656 };
657
658 // total_iterations = cdiv(range_diff, step);
659 // - range_diff = ub - lb
660 // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
661 Value zero = createConst(0);
662 Value one = createConst(1);
663 Value stepLessZero = rewriter.create<arith::CmpIOp>(
664 loc, arith::CmpIPredicate::slt, step, zero);
665 Value stepDecr =
666 rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
667
668 Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
669 Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
670 Value rangeDecr =
671 rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
672 Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
673
674 // If total_iters < max_stage, start the epilogue at zero to match the
675 // ramp-up in the prologue.
676 // start_iter = max(0, total_iters - max_stage)
677 Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
678 createConst(maxStage));
679 iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
680
681 // Capture predicates for dynamic loops.
 // Slot 0 is never assigned; slots [1, maxStage] are filled only when
 // `dynamicLoop` is set.
682 SmallVector<Value> predicates(maxStage + 1);
683
684 for (int64_t i = 1; i <= maxStage; i++) {
685 // newLastIter = lb + step * iterI
686 Value newlastIter = rewriter.create<arith::AddIOp>(
687 loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
688
689 setValueMapping(forOp.getInductionVar(), newlastIter, i);
690
691 // increment to next iterI
692 iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
693
694 if (dynamicLoop) {
695 // Disable stages when `i` is greater than total_iters.
696 // pred = total_iters >= i
697 predicates[i] = rewriter.create<arith::CmpIOp>(
698 loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
699 }
700 }
701
702 // Emit `maxStage` epilogue parts; part `i` includes operations from stages
703 // [i; maxStage].
704 for (int64_t i = 1; i <= maxStage; i++) {
 // For each return value updated in this part: the iteration argument it
 // corresponds to and the version that was produced.
705 SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
706 for (Operation *op : opOrder) {
707 if (stages[op] < i)
708 continue;
709 unsigned currentVersion = maxStage - stages[op] + i;
710 unsigned nextVersion = currentVersion + 1;
711 Operation *newOp =
712 cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) {
713 auto it = valueMapping.find(Val: newOperand->get());
714 if (it != valueMapping.end()) {
715 Value replacement = it->second[currentVersion];
716 newOperand->set(replacement);
717 }
718 });
719 if (dynamicLoop) {
720 OpBuilder::InsertionGuard insertGuard(rewriter);
721 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
722 if (!newOp)
723 return failure();
724 }
725 if (annotateFn)
726 annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
727
728 for (auto [opRes, newRes] :
729 llvm::zip(t: op->getResults(), u: newOp->getResults())) {
730 setValueMapping(key: opRes, el: newRes, idx: currentVersion);
731 // If the value is a loop carried dependency update the loop argument
732 // mapping and keep track of the last version to replace the original
733 // forOp uses.
734 for (OpOperand &operand :
735 forOp.getBody()->getTerminator()->getOpOperands()) {
736 if (operand.get() != opRes)
737 continue;
738 // If the version is greater than maxStage it means it maps to the
739 // original forOp returned value.
740 unsigned ri = operand.getOperandNumber();
741 returnValues[ri] = newRes;
742 Value mapVal = forOp.getRegionIterArgs()[ri];
743 returnMap[ri] = std::make_pair(mapVal, currentVersion);
744 if (nextVersion <= maxStage)
745 setValueMapping(mapVal, newRes, nextVersion);
746 }
747 }
748 }
749 if (dynamicLoop) {
750 // Select return values from this stage (live outs) based on predication.
751 // If the stage is valid select the peeled value, else use previous stage
752 // value.
753 for (auto pair : llvm::enumerate(First&: returnValues)) {
754 unsigned ri = pair.index();
755 auto [mapVal, currentVersion] = returnMap[ri];
 // A null `mapVal` means this return value was not updated in this part.
756 if (mapVal) {
757 unsigned nextVersion = currentVersion + 1;
758 Value pred = predicates[currentVersion];
759 Value prevValue = valueMapping[mapVal][currentVersion];
760 auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
761 prevValue);
762 returnValues[ri] = selOp;
763 if (nextVersion <= maxStage)
764 setValueMapping(key: mapVal, el: selOp, idx: nextVersion);
765 }
766 }
767 }
768 }
769 return success();
770}
771
772void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
773 auto it = valueMapping.find(Val: key);
774 // If the value is not in the map yet add a vector big enough to store all
775 // versions.
776 if (it == valueMapping.end())
777 it =
778 valueMapping
779 .insert(KV: std::make_pair(x&: key, y: llvm::SmallVector<Value>(maxStage + 1)))
780 .first;
781 it->second[idx] = el;
782}
783
784} // namespace
785
786FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
787 const PipeliningOption &options,
788 bool *modifiedIR) {
789 if (modifiedIR)
790 *modifiedIR = false;
791 LoopPipelinerInternal pipeliner;
792 if (!pipeliner.initializeLoopInfo(op: forOp, options))
793 return failure();
794
795 if (modifiedIR)
796 *modifiedIR = true;
797
798 // 1. Emit prologue.
799 if (failed(Result: pipeliner.emitPrologue(rewriter)))
800 return failure();
801
802 // 2. Track values used across stages. When a value cross stages it will
803 // need to be passed as loop iteration arguments.
804 // We first collect the values that are used in a different stage than where
805 // they are defined.
806 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
807 crossStageValues = pipeliner.analyzeCrossStageValues();
808
809 // Mapping between original loop values used cross stage and the block
810 // arguments associated after pipelining. A Value may map to several
811 // arguments if its liverange spans across more than 2 stages.
812 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
813 // 3. Create the new kernel loop and return the block arguments mapping.
814 ForOp newForOp =
815 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
816 // Create the kernel block, order ops based on user choice and remap
817 // operands.
818 if (failed(pipeliner.createKernel(newForOp: newForOp, crossStageValues, loopArgMap,
819 rewriter)))
820 return failure();
821
822 llvm::SmallVector<Value> returnValues =
823 newForOp.getResults().take_front(forOp->getNumResults());
824 if (options.peelEpilogue) {
825 // 4. Emit the epilogue after the new forOp.
826 rewriter.setInsertionPointAfter(newForOp);
827 if (failed(Result: pipeliner.emitEpilogue(rewriter, returnValues)))
828 return failure();
829 }
830 // 5. Erase the original loop and replace the uses with the epilogue output.
831 if (forOp->getNumResults() > 0)
832 rewriter.replaceOp(forOp, returnValues);
833 else
834 rewriter.eraseOp(op: forOp);
835
836 return newForOp;
837}
838
839void mlir::scf::populateSCFLoopPipeliningPatterns(
840 RewritePatternSet &patterns, const PipeliningOption &options) {
841 patterns.add<ForLoopPipeliningPattern>(arg: options, args: patterns.getContext());
842}
843

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp