1 | //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements loop software pipelining |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/SCF/IR/SCF.h" |
15 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
16 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
18 | #include "mlir/IR/IRMapping.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Transforms/RegionUtils.h" |
21 | #include "llvm/ADT/MapVector.h" |
22 | #include "llvm/Support/Debug.h" |
23 | #include "llvm/Support/MathExtras.h" |
24 | |
25 | #define DEBUG_TYPE "scf-loop-pipelining" |
26 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
27 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::scf; |
31 | |
32 | namespace { |
33 | |
34 | /// Helper to keep internal information during pipelining transformation. |
35 | struct LoopPipelinerInternal { |
36 | /// Coarse liverange information for ops used across stages. |
37 | struct LiverangeInfo { |
38 | unsigned lastUseStage = 0; |
39 | unsigned defStage = 0; |
40 | }; |
41 | |
42 | protected: |
43 | ForOp forOp; |
44 | unsigned maxStage = 0; |
45 | DenseMap<Operation *, unsigned> stages; |
46 | std::vector<Operation *> opOrder; |
47 | Value ub; |
48 | Value lb; |
49 | Value step; |
50 | bool dynamicLoop; |
51 | PipeliningOption::AnnotationlFnType annotateFn = nullptr; |
52 | bool peelEpilogue; |
53 | PipeliningOption::PredicateOpFn predicateFn = nullptr; |
54 | |
55 | // When peeling the kernel we generate several version of each value for |
56 | // different stage of the prologue. This map tracks the mapping between |
57 | // original Values in the loop and the different versions |
58 | // peeled from the loop. |
59 | DenseMap<Value, llvm::SmallVector<Value>> valueMapping; |
60 | |
61 | /// Assign a value to `valueMapping`, this means `val` represents the version |
62 | /// `idx` of `key` in the epilogue. |
63 | void setValueMapping(Value key, Value el, int64_t idx); |
64 | |
65 | /// Return the defining op of the given value, if the Value is an argument of |
66 | /// the loop return the associated defining op in the loop and its distance to |
67 | /// the Value. |
68 | std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value); |
69 | |
70 | /// Return true if the schedule is possible and return false otherwise. A |
71 | /// schedule is correct if all definitions are scheduled before uses. |
72 | bool verifySchedule(); |
73 | |
74 | public: |
75 | /// Initalize the information for the given `op`, return true if it |
76 | /// satisfies the pre-condition to apply pipelining. |
77 | bool initializeLoopInfo(ForOp op, const PipeliningOption &options); |
78 | /// Emits the prologue, this creates `maxStage - 1` part which will contain |
79 | /// operations from stages [0; i], where i is the part index. |
80 | LogicalResult emitPrologue(RewriterBase &rewriter); |
81 | /// Gather liverange information for Values that are used in a different stage |
82 | /// than its definition. |
83 | llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues(); |
84 | scf::ForOp createKernelLoop( |
85 | const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, |
86 | RewriterBase &rewriter, |
87 | llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap); |
88 | /// Emits the pipelined kernel. This clones loop operations following user |
89 | /// order and remaps operands defined in a different stage as their use. |
90 | LogicalResult createKernel( |
91 | scf::ForOp newForOp, |
92 | const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, |
93 | const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, |
94 | RewriterBase &rewriter); |
95 | /// Emits the epilogue, this creates `maxStage - 1` part which will contain |
96 | /// operations from stages [i; maxStage], where i is the part index. |
97 | LogicalResult emitEpilogue(RewriterBase &rewriter, |
98 | llvm::SmallVector<Value> &returnValues); |
99 | }; |
100 | |
101 | bool LoopPipelinerInternal::initializeLoopInfo( |
102 | ForOp op, const PipeliningOption &options) { |
103 | LDBG("Start initializeLoopInfo" ); |
104 | forOp = op; |
105 | ub = forOp.getUpperBound(); |
106 | lb = forOp.getLowerBound(); |
107 | step = forOp.getStep(); |
108 | |
109 | dynamicLoop = true; |
110 | auto upperBoundCst = getConstantIntValue(ofr: ub); |
111 | auto lowerBoundCst = getConstantIntValue(ofr: lb); |
112 | auto stepCst = getConstantIntValue(ofr: step); |
113 | if (!upperBoundCst || !lowerBoundCst || !stepCst) { |
114 | if (!options.supportDynamicLoops) { |
115 | LDBG("--dynamic loop not supported -> BAIL" ); |
116 | return false; |
117 | } |
118 | } else { |
119 | int64_t ubImm = upperBoundCst.value(); |
120 | int64_t lbImm = lowerBoundCst.value(); |
121 | int64_t stepImm = stepCst.value(); |
122 | if (stepImm <= 0) { |
123 | LDBG("--invalid loop step -> BAIL" ); |
124 | return false; |
125 | } |
126 | int64_t numIteration = llvm::divideCeilSigned(Numerator: ubImm - lbImm, Denominator: stepImm); |
127 | if (numIteration > maxStage) { |
128 | dynamicLoop = false; |
129 | } else if (!options.supportDynamicLoops) { |
130 | LDBG("--fewer loop iterations than pipeline stages -> BAIL" ); |
131 | return false; |
132 | } |
133 | } |
134 | peelEpilogue = options.peelEpilogue; |
135 | predicateFn = options.predicateFn; |
136 | if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { |
137 | LDBG("--no epilogue or predicate set -> BAIL" ); |
138 | return false; |
139 | } |
140 | std::vector<std::pair<Operation *, unsigned>> schedule; |
141 | options.getScheduleFn(forOp, schedule); |
142 | if (schedule.empty()) { |
143 | LDBG("--empty schedule -> BAIL" ); |
144 | return false; |
145 | } |
146 | |
147 | opOrder.reserve(n: schedule.size()); |
148 | for (auto &opSchedule : schedule) { |
149 | maxStage = std::max(a: maxStage, b: opSchedule.second); |
150 | stages[opSchedule.first] = opSchedule.second; |
151 | opOrder.push_back(x: opSchedule.first); |
152 | } |
153 | |
154 | // All operations need to have a stage. |
155 | for (Operation &op : forOp.getBody()->without_terminator()) { |
156 | if (!stages.contains(&op)) { |
157 | op.emitOpError("not assigned a pipeline stage" ); |
158 | LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL" ); |
159 | return false; |
160 | } |
161 | } |
162 | |
163 | if (!verifySchedule()) { |
164 | LDBG("--invalid schedule: " << op << " -> BAIL" ); |
165 | return false; |
166 | } |
167 | |
168 | // Currently, we do not support assigning stages to ops in nested regions. The |
169 | // block of all operations assigned a stage should be the single `scf.for` |
170 | // body block. |
171 | for (const auto &[op, stageNum] : stages) { |
172 | (void)stageNum; |
173 | if (op == forOp.getBody()->getTerminator()) { |
174 | op->emitError(message: "terminator should not be assigned a stage" ); |
175 | LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL" ); |
176 | return false; |
177 | } |
178 | if (op->getBlock() != forOp.getBody()) { |
179 | op->emitOpError(message: "the owning Block of all operations assigned a stage " |
180 | "should be the loop body block" ); |
181 | LDBG("--the owning Block of all operations assigned a stage " |
182 | "should be the loop body block: " |
183 | << *op << " -> BAIL" ); |
184 | return false; |
185 | } |
186 | } |
187 | |
188 | // Support only loop-carried dependencies with a distance of one iteration or |
189 | // those defined outside of the loop. This means that any dependency within a |
190 | // loop should either be on the immediately preceding iteration, the current |
191 | // iteration, or on variables whose values are set before entering the loop. |
192 | if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), |
193 | [this](Value operand) { |
194 | Operation *def = operand.getDefiningOp(); |
195 | return !def || |
196 | (!stages.contains(def) && forOp->isAncestor(def)); |
197 | })) { |
198 | LDBG("--only support loop carried dependency with a distance of 1 or " |
199 | "defined outside of the loop -> BAIL" ); |
200 | return false; |
201 | } |
202 | annotateFn = options.annotateFn; |
203 | return true; |
204 | } |
205 | |
206 | /// Find operands of all the nested operations within `op`. |
207 | static SetVector<Value> getNestedOperands(Operation *op) { |
208 | SetVector<Value> operands; |
209 | op->walk(callback: [&](Operation *nestedOp) { |
210 | operands.insert_range(R: nestedOp->getOperands()); |
211 | }); |
212 | return operands; |
213 | } |
214 | |
215 | /// Compute unrolled cycles of each op (consumer) and verify that each op is |
216 | /// scheduled after its operands (producers) while adjusting for the distance |
217 | /// between producer and consumer. |
218 | bool LoopPipelinerInternal::verifySchedule() { |
219 | int64_t numCylesPerIter = opOrder.size(); |
220 | // Pre-compute the unrolled cycle of each op. |
221 | DenseMap<Operation *, int64_t> unrolledCyles; |
222 | for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { |
223 | Operation *def = opOrder[cycle]; |
224 | auto it = stages.find(Val: def); |
225 | assert(it != stages.end()); |
226 | int64_t stage = it->second; |
227 | unrolledCyles[def] = cycle + stage * numCylesPerIter; |
228 | } |
229 | for (Operation *consumer : opOrder) { |
230 | int64_t consumerCycle = unrolledCyles[consumer]; |
231 | for (Value operand : getNestedOperands(op: consumer)) { |
232 | auto [producer, distance] = getDefiningOpAndDistance(value: operand); |
233 | if (!producer) |
234 | continue; |
235 | auto it = unrolledCyles.find(Val: producer); |
236 | // Skip producer coming from outside the loop. |
237 | if (it == unrolledCyles.end()) |
238 | continue; |
239 | int64_t producerCycle = it->second; |
240 | if (consumerCycle < producerCycle - numCylesPerIter * distance) { |
241 | consumer->emitError(message: "operation scheduled before its operands" ); |
242 | return false; |
243 | } |
244 | } |
245 | } |
246 | return true; |
247 | } |
248 | |
249 | /// Clone `op` and call `callback` on the cloned op's oeprands as well as any |
250 | /// operands of nested ops that: |
251 | /// 1) aren't defined within the new op or |
252 | /// 2) are block arguments. |
253 | static Operation * |
254 | cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, |
255 | function_ref<void(OpOperand *newOperand)> callback) { |
256 | Operation *clone = rewriter.clone(op&: *op); |
257 | clone->walk<WalkOrder::PreOrder>(callback: [&](Operation *nested) { |
258 | // 'clone' itself will be visited first. |
259 | for (OpOperand &operand : nested->getOpOperands()) { |
260 | Operation *def = operand.get().getDefiningOp(); |
261 | if ((def && !clone->isAncestor(other: def)) || isa<BlockArgument>(Val: operand.get())) |
262 | callback(&operand); |
263 | } |
264 | }); |
265 | return clone; |
266 | } |
267 | |
/// Emits the prologue: `maxStage` partial iterations cloned ahead of the
/// kernel. Part `i` (0 <= i < maxStage) clones, following `opOrder`, every op
/// whose stage `s` is <= i. That clone belongs to original iteration `i - s`
/// and reads version `i - s` of its operands from `valueMapping`. Its results
/// are recorded as version `i - s`. Yielded results also become version
/// `i - s + 1` of the matching region iter arg. When `dynamicLoop` is set,
/// clones of iteration `k` are passed to `predicateFn` with the predicate
/// `lb + k * step < ub`. Returns failure if `predicateFn` returns null.
268 | LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { |
269 | // Initialize the iteration argument to the loop initial values. |
270 | for (auto [arg, operand] : |
271 | llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { |
272 | setValueMapping(arg, operand.get(), 0); |
273 | } |
274 | auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
275 | Location loc = forOp.getLoc(); |
// predicates[k] guards the clones belonging to original iteration k. Entries
// stay null (no predication) unless `dynamicLoop` is set.
276 | SmallVector<Value> predicates(maxStage); |
277 | for (int64_t i = 0; i < maxStage; i++) { |
278 | if (dynamicLoop) { |
279 | Type t = ub.getType(); |
280 | // pred = ub > lb + (i * step) |
281 | Value iv = rewriter.create<arith::AddIOp>( |
282 | loc, lb, |
283 | rewriter.create<arith::MulIOp>( |
284 | loc, step, |
285 | rewriter.create<arith::ConstantOp>( |
286 | loc, rewriter.getIntegerAttr(t, i)))); |
287 | predicates[i] = rewriter.create<arith::CmpIOp>( |
288 | loc, arith::CmpIPredicate::slt, iv, ub); |
289 | } |
290 | |
291 | // special handling for induction variable as the increment is implicit. |
292 | // iv = lb + i * step |
293 | Type t = lb.getType(); |
294 | Value iv = rewriter.create<arith::AddIOp>( |
295 | loc, lb, |
296 | rewriter.create<arith::MulIOp>( |
297 | loc, step, |
298 | rewriter.create<arith::ConstantOp>(loc, |
299 | rewriter.getIntegerAttr(t, i)))); |
300 | setValueMapping(forOp.getInductionVar(), iv, i); |
// Clone every op whose stage s is <= i; in this part it executes original
// iteration i - s.
301 | for (Operation *op : opOrder) { |
302 | if (stages[op] > i) |
303 | continue; |
304 | Operation *newOp = |
305 | cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) { |
306 | auto it = valueMapping.find(Val: newOperand->get()); |
307 | if (it != valueMapping.end()) { |
308 | Value replacement = it->second[i - stages[op]]; |
309 | newOperand->set(replacement); |
310 | } |
311 | }); |
// Index of the original iteration this clone belongs to; it also selects
// the predicate to apply.
312 | int predicateIdx = i - stages[op]; |
313 | if (predicates[predicateIdx]) { |
314 | OpBuilder::InsertionGuard insertGuard(rewriter); |
315 | newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); |
316 | if (newOp == nullptr) |
317 | return failure(); |
318 | } |
319 | if (annotateFn) |
320 | annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); |
// Record the clone's results as version i - s. A result yielded by the loop
// also provides version i - s + 1 of the matching region iter arg, i.e. the
// value seen by the next iteration.
321 | for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) { |
322 | Value source = newOp->getResult(idx: destId); |
323 | // If the value is a loop carried dependency update the loop argument |
324 | for (OpOperand &operand : yield->getOpOperands()) { |
325 | if (operand.get() != op->getResult(destId)) |
326 | continue; |
327 | if (predicates[predicateIdx] && |
328 | !forOp.getResult(operand.getOperandNumber()).use_empty()) { |
329 | // If the value is used outside the loop, we need to make sure we |
330 | // return the correct version of it. |
331 | Value prevValue = valueMapping |
332 | [forOp.getRegionIterArgs()[operand.getOperandNumber()]] |
333 | [i - stages[op]]; |
334 | source = rewriter.create<arith::SelectOp>( |
335 | loc, predicates[predicateIdx], source, prevValue); |
336 | } |
337 | setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], |
338 | source, i - stages[op] + 1); |
339 | } |
340 | setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId), |
341 | idx: i - stages[op]); |
342 | } |
343 | } |
344 | } |
345 | return success(); |
346 | } |
347 | |
348 | llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
349 | LoopPipelinerInternal::analyzeCrossStageValues() { |
350 | llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues; |
351 | for (Operation *op : opOrder) { |
352 | unsigned stage = stages[op]; |
353 | |
354 | auto analyzeOperand = [&](OpOperand &operand) { |
355 | auto [def, distance] = getDefiningOpAndDistance(value: operand.get()); |
356 | if (!def) |
357 | return; |
358 | auto defStage = stages.find(Val: def); |
359 | if (defStage == stages.end() || defStage->second == stage || |
360 | defStage->second == stage + distance) |
361 | return; |
362 | assert(stage > defStage->second); |
363 | LiverangeInfo &info = crossStageValues[operand.get()]; |
364 | info.defStage = defStage->second; |
365 | info.lastUseStage = std::max(a: info.lastUseStage, b: stage); |
366 | }; |
367 | |
368 | for (OpOperand &operand : op->getOpOperands()) |
369 | analyzeOperand(operand); |
370 | visitUsedValuesDefinedAbove(regions: op->getRegions(), callback: [&](OpOperand *operand) { |
371 | analyzeOperand(*operand); |
372 | }); |
373 | } |
374 | return crossStageValues; |
375 | } |
376 | |
377 | std::pair<Operation *, int64_t> |
378 | LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { |
379 | int64_t distance = 0; |
380 | if (auto arg = dyn_cast<BlockArgument>(Val&: value)) { |
381 | if (arg.getOwner() != forOp.getBody()) |
382 | return {nullptr, 0}; |
383 | // Ignore induction variable. |
384 | if (arg.getArgNumber() == 0) |
385 | return {nullptr, 0}; |
386 | distance++; |
387 | value = |
388 | forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); |
389 | } |
390 | Operation *def = value.getDefiningOp(); |
391 | if (!def) |
392 | return {nullptr, 0}; |
393 | return {def, distance}; |
394 | } |
395 | |
396 | scf::ForOp LoopPipelinerInternal::createKernelLoop( |
397 | const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
398 | &crossStageValues, |
399 | RewriterBase &rewriter, |
400 | llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) { |
401 | // Creates the list of initial values associated to values used across |
402 | // stages. The initial values come from the prologue created above. |
403 | // Keep track of the kernel argument associated to each version of the |
404 | // values passed to the kernel. |
405 | llvm::SmallVector<Value> newLoopArg; |
406 | // For existing loop argument initialize them with the right version from the |
407 | // prologue. |
408 | for (const auto &retVal : |
409 | llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { |
410 | Operation *def = retVal.value().getDefiningOp(); |
411 | assert(def && "Only support loop carried dependencies of distance of 1 or " |
412 | "outside the loop" ); |
413 | auto defStage = stages.find(def); |
414 | if (defStage != stages.end()) { |
415 | Value valueVersion = |
416 | valueMapping[forOp.getRegionIterArgs()[retVal.index()]] |
417 | [maxStage - defStage->second]; |
418 | assert(valueVersion); |
419 | newLoopArg.push_back(valueVersion); |
420 | } else { |
421 | newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); |
422 | } |
423 | } |
424 | for (auto escape : crossStageValues) { |
425 | LiverangeInfo &info = escape.second; |
426 | Value value = escape.first; |
427 | for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; |
428 | stageIdx++) { |
429 | Value valueVersion = |
430 | valueMapping[value][maxStage - info.lastUseStage + stageIdx]; |
431 | assert(valueVersion); |
432 | newLoopArg.push_back(Elt: valueVersion); |
433 | loopArgMap[std::make_pair(x&: value, y: info.lastUseStage - info.defStage - |
434 | stageIdx)] = newLoopArg.size() - 1; |
435 | } |
436 | } |
437 | |
438 | // Create the new kernel loop. When we peel the epilgue we need to peel |
439 | // `numStages - 1` iterations. Then we adjust the upper bound to remove those |
440 | // iterations. |
441 | Value newUb = forOp.getUpperBound(); |
442 | if (peelEpilogue) { |
443 | Type t = ub.getType(); |
444 | Location loc = forOp.getLoc(); |
445 | // newUb = ub - maxStage * step |
446 | Value maxStageValue = rewriter.create<arith::ConstantOp>( |
447 | loc, rewriter.getIntegerAttr(t, maxStage)); |
448 | Value maxStageByStep = |
449 | rewriter.create<arith::MulIOp>(loc, step, maxStageValue); |
450 | newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep); |
451 | } |
452 | auto newForOp = |
453 | rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb, |
454 | forOp.getStep(), newLoopArg); |
455 | // When there are no iter args, the loop body terminator will be created. |
456 | // Since we always create it below, remove the terminator if it was created. |
457 | if (!newForOp.getBody()->empty()) |
458 | rewriter.eraseOp(op: newForOp.getBody()->getTerminator()); |
459 | return newForOp; |
460 | } |
461 | |
/// Fills the body of the kernel loop `newForOp` and terminates it with a
/// yield.
///
/// Every op in `opOrder` is cloned once, and its operands (including those of
/// nested ops) are remapped by stage:
/// - uses of the induction variable become `iv + (maxStage - stage) * step`;
/// - a use of an iter arg whose yielded value is produced one stage after the
///   user becomes a direct use of the cloned producer;
/// - values defined in an earlier stage are read from the kernel iter arg
///   recorded in `loopArgMap`.
/// When the epilogue is not peeled, ops of stage s < maxStage are predicated
/// through `predicateFn` on `iv < ub - (maxStage - s) * step`.
/// `valueMapping` is cleared first and refilled with the kernel results that
/// the epilogue reads. Returns failure if `predicateFn` returns null.
462 | LogicalResult LoopPipelinerInternal::createKernel( |
463 | scf::ForOp newForOp, |
464 | const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
465 | &crossStageValues, |
466 | const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, |
467 | RewriterBase &rewriter) { |
468 | valueMapping.clear(); |
469 | |
470 | // Create the kernel, we clone instruction based on the order given by |
471 | // user and remap operands coming from a previous stages. |
472 | rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); |
473 | IRMapping mapping; |
474 | mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); |
475 | for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { |
476 | mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); |
477 | } |
// predicates[s] guards the clones of stage s. They are only created when the
// epilogue is not peeled, and never for the last stage (predicates[maxStage]
// stays null).
478 | SmallVector<Value> predicates(maxStage + 1, nullptr); |
479 | if (!peelEpilogue) { |
480 | // Create a predicate for each stage except the last stage. |
481 | Location loc = newForOp.getLoc(); |
482 | Type t = ub.getType(); |
483 | for (unsigned i = 0; i < maxStage; i++) { |
484 | // c = ub - (maxStage - i) * step |
485 | Value c = rewriter.create<arith::SubIOp>( |
486 | loc, ub, |
487 | rewriter.create<arith::MulIOp>( |
488 | loc, step, |
489 | rewriter.create<arith::ConstantOp>( |
490 | loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); |
491 | |
492 | Value pred = rewriter.create<arith::CmpIOp>( |
493 | newForOp.getLoc(), arith::CmpIPredicate::slt, |
494 | newForOp.getInductionVar(), c); |
495 | predicates[i] = pred; |
496 | } |
497 | } |
// Clone each op once in schedule order, then patch the operands of the clone
// and of its nested ops according to the stage distance between definition
// and use.
498 | for (Operation *op : opOrder) { |
499 | int64_t useStage = stages[op]; |
500 | auto *newOp = rewriter.clone(op&: *op, mapper&: mapping); |
501 | SmallVector<OpOperand *> operands; |
502 | // Collect all the operands for the cloned op and its nested ops. |
503 | op->walk(callback: [&operands](Operation *nestedOp) { |
504 | for (OpOperand &operand : nestedOp->getOpOperands()) { |
505 | operands.push_back(Elt: &operand); |
506 | } |
507 | }); |
508 | for (OpOperand *operand : operands) { |
// Clone of the (possibly nested) op owning this operand.
509 | Operation *nestedNewOp = mapping.lookup(from: operand->getOwner()); |
510 | // Special case for the induction variable uses. We replace it with a |
511 | // version incremented based on the stage where it is used. |
512 | if (operand->get() == forOp.getInductionVar()) { |
513 | rewriter.setInsertionPoint(newOp); |
514 | |
515 | // offset = (maxStage - stages[op]) * step |
516 | Type t = step.getType(); |
517 | Value offset = rewriter.create<arith::MulIOp>( |
518 | forOp.getLoc(), step, |
519 | rewriter.create<arith::ConstantOp>( |
520 | forOp.getLoc(), |
521 | rewriter.getIntegerAttr(t, maxStage - stages[op]))); |
522 | Value iv = rewriter.create<arith::AddIOp>( |
523 | forOp.getLoc(), newForOp.getInductionVar(), offset); |
524 | nestedNewOp->setOperand(idx: operand->getOperandNumber(), value: iv); |
525 | rewriter.setInsertionPointAfter(newOp); |
526 | continue; |
527 | } |
528 | Value source = operand->get(); |
529 | auto arg = dyn_cast<BlockArgument>(Val&: source); |
530 | if (arg && arg.getOwner() == forOp.getBody()) { |
531 | Value ret = forOp.getBody()->getTerminator()->getOperand( |
532 | arg.getArgNumber() - 1); |
533 | Operation *dep = ret.getDefiningOp(); |
534 | if (!dep) |
535 | continue; |
536 | auto stageDep = stages.find(Val: dep); |
537 | if (stageDep == stages.end() || stageDep->second == useStage) |
538 | continue; |
539 | // If the value is a loop carried value coming from stage N + 1 remap, |
540 | // it will become a direct use. |
541 | if (stageDep->second == useStage + 1) { |
542 | nestedNewOp->setOperand(idx: operand->getOperandNumber(), |
543 | value: mapping.lookupOrDefault(from: ret)); |
544 | continue; |
545 | } |
546 | source = ret; |
547 | } |
548 | // For operands defined in a previous stage we need to remap it to use |
549 | // the correct region argument. We look for the right version of the |
550 | // Value based on the stage where it is used. |
551 | Operation *def = source.getDefiningOp(); |
552 | if (!def) |
553 | continue; |
554 | auto stageDef = stages.find(Val: def); |
555 | if (stageDef == stages.end() || stageDef->second == useStage) |
556 | continue; |
557 | auto remap = loopArgMap.find( |
558 | Val: std::make_pair(x: operand->get(), y: useStage - stageDef->second)); |
559 | assert(remap != loopArgMap.end()); |
560 | nestedNewOp->setOperand(idx: operand->getOperandNumber(), |
561 | value: newForOp.getRegionIterArgs()[remap->second]); |
562 | } |
563 | |
// Predicate the clone when its stage has a predicate, and make later clones
// use the results of the predicated op.
564 | if (predicates[useStage]) { |
565 | OpBuilder::InsertionGuard insertGuard(rewriter); |
566 | newOp = predicateFn(rewriter, newOp, predicates[useStage]); |
567 | if (!newOp) |
568 | return failure(); |
569 | // Remap the results to the new predicated one. |
570 | for (auto values : llvm::zip(t: op->getResults(), u: newOp->getResults())) |
571 | mapping.map(from: std::get<0>(t&: values), to: std::get<1>(t&: values)); |
572 | } |
573 | if (annotateFn) |
574 | annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); |
575 | } |
576 | |
577 | // Collect the Values that need to be returned by the forOp. For each |
578 | // value we need to have `LastUseStage - DefStage` number of versions |
579 | // returned. |
580 | // We create a mapping between original values and the associated loop |
581 | // returned values that will be needed by the epilogue. |
582 | llvm::SmallVector<Value> yieldOperands; |
583 | for (OpOperand &yieldOperand : |
584 | forOp.getBody()->getTerminator()->getOpOperands()) { |
585 | Value source = mapping.lookupOrDefault(yieldOperand.get()); |
586 | // When we don't peel the epilogue and the yield value is used outside the |
587 | // loop we need to make sure we return the version from numStages - |
588 | // defStage. |
589 | if (!peelEpilogue && |
590 | !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { |
591 | Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; |
592 | if (def) { |
593 | auto defStage = stages.find(def); |
594 | if (defStage != stages.end() && defStage->second < maxStage) { |
595 | Value pred = predicates[defStage->second]; |
596 | source = rewriter.create<arith::SelectOp>( |
597 | pred.getLoc(), pred, source, |
598 | newForOp.getBody() |
599 | ->getArguments()[yieldOperand.getOperandNumber() + 1]); |
600 | } |
601 | } |
602 | } |
603 | yieldOperands.push_back(source); |
604 | } |
605 | |
// Versions of each cross-stage value rotate through its kernel iter args.
// Each iter arg but the last yields the next iter arg's current value, and
// the last one yields the value computed in this iteration. The matching
// kernel results are recorded in valueMapping, starting at version
// maxStage - lastUseStage + 1, for use by the epilogue.
606 | for (auto &it : crossStageValues) { |
607 | int64_t version = maxStage - it.second.lastUseStage + 1; |
608 | unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; |
609 | // add the original version to yield ops. |
610 | // If there is a live range spanning across more than 2 stages we need to |
611 | // add extra arg. |
612 | for (unsigned i = 1; i < numVersionReturned; i++) { |
613 | setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()), |
614 | idx: version++); |
615 | yieldOperands.push_back( |
616 | Elt: newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + |
617 | newForOp.getNumInductionVars()]); |
618 | } |
619 | setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()), |
620 | idx: version++); |
621 | yieldOperands.push_back(Elt: mapping.lookupOrDefault(from: it.first)); |
622 | } |
623 | // Map the yield operand to the forOp returned value. |
// A yielded value whose producer is not scheduled keeps that value for every
// epilogue version. A producer of stage s > 0 maps version maxStage - s + 1
// to the kernel result.
624 | for (const auto &retVal : |
625 | llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { |
626 | Operation *def = retVal.value().getDefiningOp(); |
627 | assert(def && "Only support loop carried dependencies of distance of 1 or " |
628 | "defined outside the loop" ); |
629 | auto defStage = stages.find(def); |
630 | if (defStage == stages.end()) { |
631 | for (unsigned int stage = 1; stage <= maxStage; stage++) |
632 | setValueMapping(forOp.getRegionIterArgs()[retVal.index()], |
633 | retVal.value(), stage); |
634 | } else if (defStage->second > 0) { |
635 | setValueMapping(forOp.getRegionIterArgs()[retVal.index()], |
636 | newForOp->getResult(retVal.index()), |
637 | maxStage - defStage->second + 1); |
638 | } |
639 | } |
640 | rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands); |
641 | return success(); |
642 | } |
643 | |
/// Emits the peeled epilogue: `maxStage` parts. Part `i` (1 <= i <= maxStage)
/// clones, following `opOrder`, every op whose stage `s` is >= i. Each clone
/// reads version `maxStage - s + i` of its operands from `valueMapping`.
/// Version `v` of the induction variable is `lb + step * (start_iter + v - 1)`,
/// where `start_iter = max(0, total_iterations - maxStage)`. When
/// `dynamicLoop` is set, a clone reading version `v` is predicated through
/// `predicateFn` on `total_iterations >= v`, and loop results are selected
/// between the new and the previous version. `returnValues[r]` is updated with
/// the latest version of the value yielded for loop result `r`.
/// NOTE(review): `returnValues` is indexed by result number and must
/// presumably be pre-sized by the caller — confirm.
/// Returns failure if `predicateFn` returns null.
644 | LogicalResult |
645 | LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, |
646 | llvm::SmallVector<Value> &returnValues) { |
647 | Location loc = forOp.getLoc(); |
648 | Type t = lb.getType(); |
649 | |
650 | // Emit different versions of the induction variable. They will be |
651 | // removed by dead code if not used. |
652 | |
// NOTE(review): `v` is an `int`. Callers pass `maxStage` (unsigned) and `i`
// (int64_t), which are narrowed; this only matters for huge stage counts.
653 | auto createConst = [&](int v) { |
654 | return rewriter.create<arith::ConstantOp>(loc, |
655 | rewriter.getIntegerAttr(t, v)); |
656 | }; |
657 | |
658 | // total_iterations = cdiv(range_diff, step); |
659 | // - range_diff = ub - lb |
660 | // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step |
661 | Value zero = createConst(0); |
662 | Value one = createConst(1); |
663 | Value stepLessZero = rewriter.create<arith::CmpIOp>( |
664 | loc, arith::CmpIPredicate::slt, step, zero); |
665 | Value stepDecr = |
666 | rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1)); |
667 | |
668 | Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb); |
669 | Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step); |
670 | Value rangeDecr = |
671 | rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr); |
672 | Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step); |
673 | |
674 | // If total_iters < max_stage, start the epilogue at zero to match the |
675 | // ramp-up in the prologue. |
676 | // start_iter = max(0, total_iters - max_stage) |
677 | Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations, |
678 | createConst(maxStage)); |
679 | iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI); |
680 | |
681 | // Capture predicates for dynamic loops. |
// predicates[v] = (total_iterations >= v) guards the clones reading version v.
682 | SmallVector<Value> predicates(maxStage + 1); |
683 | |
684 | for (int64_t i = 1; i <= maxStage; i++) { |
685 | // newLastIter = lb + step * iterI |
686 | Value newlastIter = rewriter.create<arith::AddIOp>( |
687 | loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI)); |
688 | |
689 | setValueMapping(forOp.getInductionVar(), newlastIter, i); |
690 | |
691 | // increment to next iterI |
692 | iterI = rewriter.create<arith::AddIOp>(loc, iterI, one); |
693 | |
694 | if (dynamicLoop) { |
695 | // Disable stages when `i` is greater than total_iters. |
696 | // pred = total_iters >= i |
697 | predicates[i] = rewriter.create<arith::CmpIOp>( |
698 | loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); |
699 | } |
700 | } |
701 | |
702 | // Emit `maxStage` epilogue parts; part i includes operations from stages |
703 | // [i; maxStage]. |
704 | for (int64_t i = 1; i <= maxStage; i++) { |
// returnMap[r] remembers which iter arg, and at which version, produced loop
// result r in this part. It drives the selects emitted below.
705 | SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size()); |
706 | for (Operation *op : opOrder) { |
707 | if (stages[op] < i) |
708 | continue; |
// Version read by this clone. For yielded results, nextVersion is the
// version handed to the next part through the matching iter arg.
709 | unsigned currentVersion = maxStage - stages[op] + i; |
710 | unsigned nextVersion = currentVersion + 1; |
711 | Operation *newOp = |
712 | cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) { |
713 | auto it = valueMapping.find(Val: newOperand->get()); |
714 | if (it != valueMapping.end()) { |
715 | Value replacement = it->second[currentVersion]; |
716 | newOperand->set(replacement); |
717 | } |
718 | }); |
719 | if (dynamicLoop) { |
720 | OpBuilder::InsertionGuard insertGuard(rewriter); |
721 | newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); |
722 | if (!newOp) |
723 | return failure(); |
724 | } |
725 | if (annotateFn) |
726 | annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); |
727 | |
728 | for (auto [opRes, newRes] : |
729 | llvm::zip(t: op->getResults(), u: newOp->getResults())) { |
730 | setValueMapping(key: opRes, el: newRes, idx: currentVersion); |
731 | // If the value is a loop carried dependency update the loop argument |
732 | // mapping and keep track of the last version to replace the original |
733 | // forOp uses. |
734 | for (OpOperand &operand : |
735 | forOp.getBody()->getTerminator()->getOpOperands()) { |
736 | if (operand.get() != opRes) |
737 | continue; |
738 | // If the version is greater than maxStage it means it maps to the |
739 | // original forOp returned value. |
740 | unsigned ri = operand.getOperandNumber(); |
741 | returnValues[ri] = newRes; |
742 | Value mapVal = forOp.getRegionIterArgs()[ri]; |
743 | returnMap[ri] = std::make_pair(mapVal, currentVersion); |
744 | if (nextVersion <= maxStage) |
745 | setValueMapping(mapVal, newRes, nextVersion); |
746 | } |
747 | } |
748 | } |
749 | if (dynamicLoop) { |
750 | // Select return values from this stage (live outs) based on predication. |
751 | // If the stage is valid select the peeled value, else use previous stage |
752 | // value. |
753 | for (auto pair : llvm::enumerate(First&: returnValues)) { |
754 | unsigned ri = pair.index(); |
755 | auto [mapVal, currentVersion] = returnMap[ri]; |
756 | if (mapVal) { |
757 | unsigned nextVersion = currentVersion + 1; |
758 | Value pred = predicates[currentVersion]; |
759 | Value prevValue = valueMapping[mapVal][currentVersion]; |
760 | auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(), |
761 | prevValue); |
762 | returnValues[ri] = selOp; |
763 | if (nextVersion <= maxStage) |
764 | setValueMapping(key: mapVal, el: selOp, idx: nextVersion); |
765 | } |
766 | } |
767 | } |
768 | } |
769 | return success(); |
770 | } |
771 | |
772 | void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { |
773 | auto it = valueMapping.find(Val: key); |
774 | // If the value is not in the map yet add a vector big enough to store all |
775 | // versions. |
776 | if (it == valueMapping.end()) |
777 | it = |
778 | valueMapping |
779 | .insert(KV: std::make_pair(x&: key, y: llvm::SmallVector<Value>(maxStage + 1))) |
780 | .first; |
781 | it->second[idx] = el; |
782 | } |
783 | |
784 | } // namespace |
785 | |
786 | FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, |
787 | const PipeliningOption &options, |
788 | bool *modifiedIR) { |
789 | if (modifiedIR) |
790 | *modifiedIR = false; |
791 | LoopPipelinerInternal pipeliner; |
792 | if (!pipeliner.initializeLoopInfo(op: forOp, options)) |
793 | return failure(); |
794 | |
795 | if (modifiedIR) |
796 | *modifiedIR = true; |
797 | |
798 | // 1. Emit prologue. |
799 | if (failed(Result: pipeliner.emitPrologue(rewriter))) |
800 | return failure(); |
801 | |
802 | // 2. Track values used across stages. When a value cross stages it will |
803 | // need to be passed as loop iteration arguments. |
804 | // We first collect the values that are used in a different stage than where |
805 | // they are defined. |
806 | llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
807 | crossStageValues = pipeliner.analyzeCrossStageValues(); |
808 | |
809 | // Mapping between original loop values used cross stage and the block |
810 | // arguments associated after pipelining. A Value may map to several |
811 | // arguments if its liverange spans across more than 2 stages. |
812 | llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap; |
813 | // 3. Create the new kernel loop and return the block arguments mapping. |
814 | ForOp newForOp = |
815 | pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); |
816 | // Create the kernel block, order ops based on user choice and remap |
817 | // operands. |
818 | if (failed(pipeliner.createKernel(newForOp: newForOp, crossStageValues, loopArgMap, |
819 | rewriter))) |
820 | return failure(); |
821 | |
822 | llvm::SmallVector<Value> returnValues = |
823 | newForOp.getResults().take_front(forOp->getNumResults()); |
824 | if (options.peelEpilogue) { |
825 | // 4. Emit the epilogue after the new forOp. |
826 | rewriter.setInsertionPointAfter(newForOp); |
827 | if (failed(Result: pipeliner.emitEpilogue(rewriter, returnValues))) |
828 | return failure(); |
829 | } |
830 | // 5. Erase the original loop and replace the uses with the epilogue output. |
831 | if (forOp->getNumResults() > 0) |
832 | rewriter.replaceOp(forOp, returnValues); |
833 | else |
834 | rewriter.eraseOp(op: forOp); |
835 | |
836 | return newForOp; |
837 | } |
838 | |
839 | void mlir::scf::populateSCFLoopPipeliningPatterns( |
840 | RewritePatternSet &patterns, const PipeliningOption &options) { |
841 | patterns.add<ForLoopPipeliningPattern>(arg: options, args: patterns.getContext()); |
842 | } |
843 | |