1 | //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements loop software pipelining |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/SCF/IR/SCF.h" |
15 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
16 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
18 | #include "mlir/IR/IRMapping.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Support/MathExtras.h" |
21 | #include "mlir/Transforms/RegionUtils.h" |
22 | #include "llvm/ADT/MapVector.h" |
23 | #include "llvm/Support/Debug.h" |
24 | |
25 | #define DEBUG_TYPE "scf-loop-pipelining" |
26 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
27 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::scf; |
31 | |
32 | namespace { |
33 | |
34 | /// Helper to keep internal information during pipelining transformation. |
35 | struct LoopPipelinerInternal { |
36 | /// Coarse liverange information for ops used across stages. |
37 | struct LiverangeInfo { |
38 | unsigned lastUseStage = 0; |
39 | unsigned defStage = 0; |
40 | }; |
41 | |
42 | protected: |
43 | ForOp forOp; |
44 | unsigned maxStage = 0; |
45 | DenseMap<Operation *, unsigned> stages; |
46 | std::vector<Operation *> opOrder; |
47 | Value ub; |
48 | Value lb; |
49 | Value step; |
50 | bool dynamicLoop; |
51 | PipeliningOption::AnnotationlFnType annotateFn = nullptr; |
52 | bool peelEpilogue; |
53 | PipeliningOption::PredicateOpFn predicateFn = nullptr; |
54 | |
55 | // When peeling the kernel we generate several version of each value for |
56 | // different stage of the prologue. This map tracks the mapping between |
57 | // original Values in the loop and the different versions |
58 | // peeled from the loop. |
59 | DenseMap<Value, llvm::SmallVector<Value>> valueMapping; |
60 | |
61 | /// Assign a value to `valueMapping`, this means `val` represents the version |
62 | /// `idx` of `key` in the epilogue. |
63 | void setValueMapping(Value key, Value el, int64_t idx); |
64 | |
65 | /// Return the defining op of the given value, if the Value is an argument of |
66 | /// the loop return the associated defining op in the loop and its distance to |
67 | /// the Value. |
68 | std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value); |
69 | |
70 | /// Return true if the schedule is possible and return false otherwise. A |
71 | /// schedule is correct if all definitions are scheduled before uses. |
72 | bool verifySchedule(); |
73 | |
74 | public: |
75 | /// Initalize the information for the given `op`, return true if it |
76 | /// satisfies the pre-condition to apply pipelining. |
77 | bool initializeLoopInfo(ForOp op, const PipeliningOption &options); |
78 | /// Emits the prologue, this creates `maxStage - 1` part which will contain |
79 | /// operations from stages [0; i], where i is the part index. |
80 | void emitPrologue(RewriterBase &rewriter); |
81 | /// Gather liverange information for Values that are used in a different stage |
82 | /// than its definition. |
83 | llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues(); |
84 | scf::ForOp createKernelLoop( |
85 | const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, |
86 | RewriterBase &rewriter, |
87 | llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap); |
88 | /// Emits the pipelined kernel. This clones loop operations following user |
89 | /// order and remaps operands defined in a different stage as their use. |
90 | LogicalResult createKernel( |
91 | scf::ForOp newForOp, |
92 | const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, |
93 | const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, |
94 | RewriterBase &rewriter); |
95 | /// Emits the epilogue, this creates `maxStage - 1` part which will contain |
96 | /// operations from stages [i; maxStage], where i is the part index. |
97 | void emitEpilogue(RewriterBase &rewriter, |
98 | llvm::SmallVector<Value> &returnValues); |
99 | }; |
100 | |
101 | bool LoopPipelinerInternal::initializeLoopInfo( |
102 | ForOp op, const PipeliningOption &options) { |
103 | LDBG("Start initializeLoopInfo" ); |
104 | forOp = op; |
105 | ub = forOp.getUpperBound(); |
106 | lb = forOp.getLowerBound(); |
107 | step = forOp.getStep(); |
108 | |
109 | dynamicLoop = true; |
110 | auto upperBoundCst = getConstantIntValue(ofr: ub); |
111 | auto lowerBoundCst = getConstantIntValue(ofr: lb); |
112 | auto stepCst = getConstantIntValue(ofr: step); |
113 | if (!upperBoundCst || !lowerBoundCst || !stepCst) { |
114 | if (!options.supportDynamicLoops) { |
115 | LDBG("--dynamic loop not supported -> BAIL" ); |
116 | return false; |
117 | } |
118 | } else { |
119 | int64_t ubImm = upperBoundCst.value(); |
120 | int64_t lbImm = lowerBoundCst.value(); |
121 | int64_t stepImm = stepCst.value(); |
122 | int64_t numIteration = ceilDiv(lhs: ubImm - lbImm, rhs: stepImm); |
123 | if (numIteration > maxStage) { |
124 | dynamicLoop = false; |
125 | } else if (!options.supportDynamicLoops) { |
126 | LDBG("--fewer loop iterations than pipeline stages -> BAIL" ); |
127 | return false; |
128 | } |
129 | } |
130 | peelEpilogue = options.peelEpilogue; |
131 | predicateFn = options.predicateFn; |
132 | if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { |
133 | LDBG("--no epilogue or predicate set -> BAIL" ); |
134 | return false; |
135 | } |
136 | if (dynamicLoop && peelEpilogue) { |
137 | LDBG("--dynamic loop doesn't support epilogue yet -> BAIL" ); |
138 | return false; |
139 | } |
140 | std::vector<std::pair<Operation *, unsigned>> schedule; |
141 | options.getScheduleFn(forOp, schedule); |
142 | if (schedule.empty()) { |
143 | LDBG("--empty schedule -> BAIL" ); |
144 | return false; |
145 | } |
146 | |
147 | opOrder.reserve(n: schedule.size()); |
148 | for (auto &opSchedule : schedule) { |
149 | maxStage = std::max(a: maxStage, b: opSchedule.second); |
150 | stages[opSchedule.first] = opSchedule.second; |
151 | opOrder.push_back(x: opSchedule.first); |
152 | } |
153 | |
154 | // All operations need to have a stage. |
155 | for (Operation &op : forOp.getBody()->without_terminator()) { |
156 | if (!stages.contains(&op)) { |
157 | op.emitOpError("not assigned a pipeline stage" ); |
158 | LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL" ); |
159 | return false; |
160 | } |
161 | } |
162 | |
163 | if (!verifySchedule()) { |
164 | LDBG("--invalid schedule: " << op << " -> BAIL" ); |
165 | return false; |
166 | } |
167 | |
168 | // Currently, we do not support assigning stages to ops in nested regions. The |
169 | // block of all operations assigned a stage should be the single `scf.for` |
170 | // body block. |
171 | for (const auto &[op, stageNum] : stages) { |
172 | (void)stageNum; |
173 | if (op == forOp.getBody()->getTerminator()) { |
174 | op->emitError(message: "terminator should not be assigned a stage" ); |
175 | LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL" ); |
176 | return false; |
177 | } |
178 | if (op->getBlock() != forOp.getBody()) { |
179 | op->emitOpError(message: "the owning Block of all operations assigned a stage " |
180 | "should be the loop body block" ); |
181 | LDBG("--the owning Block of all operations assigned a stage " |
182 | "should be the loop body block: " |
183 | << *op << " -> BAIL" ); |
184 | return false; |
185 | } |
186 | } |
187 | |
188 | // Support only loop-carried dependencies with a distance of one iteration or |
189 | // those defined outside of the loop. This means that any dependency within a |
190 | // loop should either be on the immediately preceding iteration, the current |
191 | // iteration, or on variables whose values are set before entering the loop. |
192 | if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), |
193 | [this](Value operand) { |
194 | Operation *def = operand.getDefiningOp(); |
195 | return !def || |
196 | (!stages.contains(def) && forOp->isAncestor(def)); |
197 | })) { |
198 | LDBG("--only support loop carried dependency with a distance of 1 or " |
199 | "defined outside of the loop -> BAIL" ); |
200 | return false; |
201 | } |
202 | annotateFn = options.annotateFn; |
203 | return true; |
204 | } |
205 | |
206 | /// Find operands of all the nested operations within `op`. |
207 | static SetVector<Value> getNestedOperands(Operation *op) { |
208 | SetVector<Value> operands; |
209 | op->walk(callback: [&](Operation *nestedOp) { |
210 | for (Value operand : nestedOp->getOperands()) { |
211 | operands.insert(X: operand); |
212 | } |
213 | }); |
214 | return operands; |
215 | } |
216 | |
217 | /// Compute unrolled cycles of each op (consumer) and verify that each op is |
218 | /// scheduled after its operands (producers) while adjusting for the distance |
219 | /// between producer and consumer. |
220 | bool LoopPipelinerInternal::verifySchedule() { |
221 | int64_t numCylesPerIter = opOrder.size(); |
222 | // Pre-compute the unrolled cycle of each op. |
223 | DenseMap<Operation *, int64_t> unrolledCyles; |
224 | for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { |
225 | Operation *def = opOrder[cycle]; |
226 | auto it = stages.find(Val: def); |
227 | assert(it != stages.end()); |
228 | int64_t stage = it->second; |
229 | unrolledCyles[def] = cycle + stage * numCylesPerIter; |
230 | } |
231 | for (Operation *consumer : opOrder) { |
232 | int64_t consumerCycle = unrolledCyles[consumer]; |
233 | for (Value operand : getNestedOperands(op: consumer)) { |
234 | auto [producer, distance] = getDefiningOpAndDistance(value: operand); |
235 | if (!producer) |
236 | continue; |
237 | auto it = unrolledCyles.find(Val: producer); |
238 | // Skip producer coming from outside the loop. |
239 | if (it == unrolledCyles.end()) |
240 | continue; |
241 | int64_t producerCycle = it->second; |
242 | if (consumerCycle < producerCycle - numCylesPerIter * distance) { |
243 | consumer->emitError(message: "operation scheduled before its operands" ); |
244 | return false; |
245 | } |
246 | } |
247 | } |
248 | return true; |
249 | } |
250 | |
251 | /// Clone `op` and call `callback` on the cloned op's oeprands as well as any |
252 | /// operands of nested ops that: |
253 | /// 1) aren't defined within the new op or |
254 | /// 2) are block arguments. |
255 | static Operation * |
256 | cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, |
257 | function_ref<void(OpOperand *newOperand)> callback) { |
258 | Operation *clone = rewriter.clone(op&: *op); |
259 | clone->walk<WalkOrder::PreOrder>(callback: [&](Operation *nested) { |
260 | // 'clone' itself will be visited first. |
261 | for (OpOperand &operand : nested->getOpOperands()) { |
262 | Operation *def = operand.get().getDefiningOp(); |
263 | if ((def && !clone->isAncestor(other: def)) || isa<BlockArgument>(Val: operand.get())) |
264 | callback(&operand); |
265 | } |
266 | }); |
267 | return clone; |
268 | } |
269 | |
270 | void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { |
271 | // Initialize the iteration argument to the loop initiale values. |
272 | for (auto [arg, operand] : |
273 | llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { |
274 | setValueMapping(arg, operand.get(), 0); |
275 | } |
276 | auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
277 | Location loc = forOp.getLoc(); |
278 | SmallVector<Value> predicates(maxStage); |
279 | for (int64_t i = 0; i < maxStage; i++) { |
280 | if (dynamicLoop) { |
281 | Type t = ub.getType(); |
282 | // pred = ub > lb + (i * step) |
283 | Value iv = rewriter.create<arith::AddIOp>( |
284 | loc, lb, |
285 | rewriter.create<arith::MulIOp>( |
286 | loc, step, |
287 | rewriter.create<arith::ConstantOp>( |
288 | loc, rewriter.getIntegerAttr(t, i)))); |
289 | predicates[i] = rewriter.create<arith::CmpIOp>( |
290 | loc, arith::CmpIPredicate::slt, iv, ub); |
291 | } |
292 | |
293 | // special handling for induction variable as the increment is implicit. |
294 | // iv = lb + i * step |
295 | Type t = lb.getType(); |
296 | Value iv = rewriter.create<arith::AddIOp>( |
297 | loc, lb, |
298 | rewriter.create<arith::MulIOp>( |
299 | loc, step, |
300 | rewriter.create<arith::ConstantOp>(loc, |
301 | rewriter.getIntegerAttr(t, i)))); |
302 | setValueMapping(forOp.getInductionVar(), iv, i); |
303 | for (Operation *op : opOrder) { |
304 | if (stages[op] > i) |
305 | continue; |
306 | Operation *newOp = |
307 | cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) { |
308 | auto it = valueMapping.find(Val: newOperand->get()); |
309 | if (it != valueMapping.end()) { |
310 | Value replacement = it->second[i - stages[op]]; |
311 | newOperand->set(replacement); |
312 | } |
313 | }); |
314 | int predicateIdx = i - stages[op]; |
315 | if (predicates[predicateIdx]) { |
316 | newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); |
317 | assert(newOp && "failed to predicate op." ); |
318 | } |
319 | rewriter.setInsertionPointAfter(newOp); |
320 | if (annotateFn) |
321 | annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); |
322 | for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) { |
323 | setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId), |
324 | idx: i - stages[op]); |
325 | // If the value is a loop carried dependency update the loop argument |
326 | // mapping. |
327 | for (OpOperand &operand : yield->getOpOperands()) { |
328 | if (operand.get() != op->getResult(destId)) |
329 | continue; |
330 | setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], |
331 | newOp->getResult(destId), i - stages[op] + 1); |
332 | } |
333 | } |
334 | } |
335 | } |
336 | } |
337 | |
338 | llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
339 | LoopPipelinerInternal::analyzeCrossStageValues() { |
340 | llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues; |
341 | for (Operation *op : opOrder) { |
342 | unsigned stage = stages[op]; |
343 | |
344 | auto analyzeOperand = [&](OpOperand &operand) { |
345 | auto [def, distance] = getDefiningOpAndDistance(value: operand.get()); |
346 | if (!def) |
347 | return; |
348 | auto defStage = stages.find(Val: def); |
349 | if (defStage == stages.end() || defStage->second == stage || |
350 | defStage->second == stage + distance) |
351 | return; |
352 | assert(stage > defStage->second); |
353 | LiverangeInfo &info = crossStageValues[operand.get()]; |
354 | info.defStage = defStage->second; |
355 | info.lastUseStage = std::max(a: info.lastUseStage, b: stage); |
356 | }; |
357 | |
358 | for (OpOperand &operand : op->getOpOperands()) |
359 | analyzeOperand(operand); |
360 | visitUsedValuesDefinedAbove(regions: op->getRegions(), callback: [&](OpOperand *operand) { |
361 | analyzeOperand(*operand); |
362 | }); |
363 | } |
364 | return crossStageValues; |
365 | } |
366 | |
367 | std::pair<Operation *, int64_t> |
368 | LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { |
369 | int64_t distance = 0; |
370 | if (auto arg = dyn_cast<BlockArgument>(Val&: value)) { |
371 | if (arg.getOwner() != forOp.getBody()) |
372 | return {nullptr, 0}; |
373 | // Ignore induction variable. |
374 | if (arg.getArgNumber() == 0) |
375 | return {nullptr, 0}; |
376 | distance++; |
377 | value = |
378 | forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); |
379 | } |
380 | Operation *def = value.getDefiningOp(); |
381 | if (!def) |
382 | return {nullptr, 0}; |
383 | return {def, distance}; |
384 | } |
385 | |
386 | scf::ForOp LoopPipelinerInternal::createKernelLoop( |
387 | const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
388 | &crossStageValues, |
389 | RewriterBase &rewriter, |
390 | llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) { |
391 | // Creates the list of initial values associated to values used across |
392 | // stages. The initial values come from the prologue created above. |
393 | // Keep track of the kernel argument associated to each version of the |
394 | // values passed to the kernel. |
395 | llvm::SmallVector<Value> newLoopArg; |
396 | // For existing loop argument initialize them with the right version from the |
397 | // prologue. |
398 | for (const auto &retVal : |
399 | llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { |
400 | Operation *def = retVal.value().getDefiningOp(); |
401 | assert(def && "Only support loop carried dependencies of distance of 1 or " |
402 | "outside the loop" ); |
403 | auto defStage = stages.find(def); |
404 | if (defStage != stages.end()) { |
405 | Value valueVersion = |
406 | valueMapping[forOp.getRegionIterArgs()[retVal.index()]] |
407 | [maxStage - defStage->second]; |
408 | assert(valueVersion); |
409 | newLoopArg.push_back(valueVersion); |
410 | } else |
411 | newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); |
412 | } |
413 | for (auto escape : crossStageValues) { |
414 | LiverangeInfo &info = escape.second; |
415 | Value value = escape.first; |
416 | for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; |
417 | stageIdx++) { |
418 | Value valueVersion = |
419 | valueMapping[value][maxStage - info.lastUseStage + stageIdx]; |
420 | assert(valueVersion); |
421 | newLoopArg.push_back(Elt: valueVersion); |
422 | loopArgMap[std::make_pair(x&: value, y: info.lastUseStage - info.defStage - |
423 | stageIdx)] = newLoopArg.size() - 1; |
424 | } |
425 | } |
426 | |
427 | // Create the new kernel loop. When we peel the epilgue we need to peel |
428 | // `numStages - 1` iterations. Then we adjust the upper bound to remove those |
429 | // iterations. |
430 | Value newUb = forOp.getUpperBound(); |
431 | if (peelEpilogue) { |
432 | Type t = ub.getType(); |
433 | Location loc = forOp.getLoc(); |
434 | // newUb = ub - maxStage * step |
435 | Value maxStageValue = rewriter.create<arith::ConstantOp>( |
436 | loc, rewriter.getIntegerAttr(t, maxStage)); |
437 | Value maxStageByStep = |
438 | rewriter.create<arith::MulIOp>(loc, step, maxStageValue); |
439 | newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep); |
440 | } |
441 | auto newForOp = |
442 | rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb, |
443 | forOp.getStep(), newLoopArg); |
444 | // When there are no iter args, the loop body terminator will be created. |
445 | // Since we always create it below, remove the terminator if it was created. |
446 | if (!newForOp.getBody()->empty()) |
447 | rewriter.eraseOp(op: newForOp.getBody()->getTerminator()); |
448 | return newForOp; |
449 | } |
450 | |
451 | LogicalResult LoopPipelinerInternal::createKernel( |
452 | scf::ForOp newForOp, |
453 | const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
454 | &crossStageValues, |
455 | const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, |
456 | RewriterBase &rewriter) { |
457 | valueMapping.clear(); |
458 | |
459 | // Create the kernel, we clone instruction based on the order given by |
460 | // user and remap operands coming from a previous stages. |
461 | rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); |
462 | IRMapping mapping; |
463 | mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); |
464 | for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { |
465 | mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); |
466 | } |
467 | SmallVector<Value> predicates(maxStage + 1, nullptr); |
468 | if (!peelEpilogue) { |
469 | // Create a predicate for each stage except the last stage. |
470 | Location loc = newForOp.getLoc(); |
471 | Type t = ub.getType(); |
472 | for (unsigned i = 0; i < maxStage; i++) { |
473 | // c = ub - (maxStage - i) * step |
474 | Value c = rewriter.create<arith::SubIOp>( |
475 | loc, ub, |
476 | rewriter.create<arith::MulIOp>( |
477 | loc, step, |
478 | rewriter.create<arith::ConstantOp>( |
479 | loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); |
480 | |
481 | Value pred = rewriter.create<arith::CmpIOp>( |
482 | newForOp.getLoc(), arith::CmpIPredicate::slt, |
483 | newForOp.getInductionVar(), c); |
484 | predicates[i] = pred; |
485 | } |
486 | } |
487 | for (Operation *op : opOrder) { |
488 | int64_t useStage = stages[op]; |
489 | auto *newOp = rewriter.clone(op&: *op, mapper&: mapping); |
490 | SmallVector<OpOperand *> operands; |
491 | // Collect all the operands for the cloned op and its nested ops. |
492 | op->walk(callback: [&operands](Operation *nestedOp) { |
493 | for (OpOperand &operand : nestedOp->getOpOperands()) { |
494 | operands.push_back(Elt: &operand); |
495 | } |
496 | }); |
497 | for (OpOperand *operand : operands) { |
498 | Operation *nestedNewOp = mapping.lookup(from: operand->getOwner()); |
499 | // Special case for the induction variable uses. We replace it with a |
500 | // version incremented based on the stage where it is used. |
501 | if (operand->get() == forOp.getInductionVar()) { |
502 | rewriter.setInsertionPoint(newOp); |
503 | |
504 | // offset = (maxStage - stages[op]) * step |
505 | Type t = step.getType(); |
506 | Value offset = rewriter.create<arith::MulIOp>( |
507 | forOp.getLoc(), step, |
508 | rewriter.create<arith::ConstantOp>( |
509 | forOp.getLoc(), |
510 | rewriter.getIntegerAttr(t, maxStage - stages[op]))); |
511 | Value iv = rewriter.create<arith::AddIOp>( |
512 | forOp.getLoc(), newForOp.getInductionVar(), offset); |
513 | nestedNewOp->setOperand(idx: operand->getOperandNumber(), value: iv); |
514 | rewriter.setInsertionPointAfter(newOp); |
515 | continue; |
516 | } |
517 | Value source = operand->get(); |
518 | auto arg = dyn_cast<BlockArgument>(Val&: source); |
519 | if (arg && arg.getOwner() == forOp.getBody()) { |
520 | Value ret = forOp.getBody()->getTerminator()->getOperand( |
521 | arg.getArgNumber() - 1); |
522 | Operation *dep = ret.getDefiningOp(); |
523 | if (!dep) |
524 | continue; |
525 | auto stageDep = stages.find(Val: dep); |
526 | if (stageDep == stages.end() || stageDep->second == useStage) |
527 | continue; |
528 | // If the value is a loop carried value coming from stage N + 1 remap, |
529 | // it will become a direct use. |
530 | if (stageDep->second == useStage + 1) { |
531 | nestedNewOp->setOperand(idx: operand->getOperandNumber(), |
532 | value: mapping.lookupOrDefault(from: ret)); |
533 | continue; |
534 | } |
535 | source = ret; |
536 | } |
537 | // For operands defined in a previous stage we need to remap it to use |
538 | // the correct region argument. We look for the right version of the |
539 | // Value based on the stage where it is used. |
540 | Operation *def = source.getDefiningOp(); |
541 | if (!def) |
542 | continue; |
543 | auto stageDef = stages.find(Val: def); |
544 | if (stageDef == stages.end() || stageDef->second == useStage) |
545 | continue; |
546 | auto remap = loopArgMap.find( |
547 | Val: std::make_pair(x: operand->get(), y: useStage - stageDef->second)); |
548 | assert(remap != loopArgMap.end()); |
549 | nestedNewOp->setOperand(idx: operand->getOperandNumber(), |
550 | value: newForOp.getRegionIterArgs()[remap->second]); |
551 | } |
552 | |
553 | if (predicates[useStage]) { |
554 | newOp = predicateFn(rewriter, newOp, predicates[useStage]); |
555 | if (!newOp) |
556 | return failure(); |
557 | // Remap the results to the new predicated one. |
558 | for (auto values : llvm::zip(t: op->getResults(), u: newOp->getResults())) |
559 | mapping.map(from: std::get<0>(t&: values), to: std::get<1>(t&: values)); |
560 | } |
561 | rewriter.setInsertionPointAfter(newOp); |
562 | if (annotateFn) |
563 | annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); |
564 | } |
565 | |
566 | // Collect the Values that need to be returned by the forOp. For each |
567 | // value we need to have `LastUseStage - DefStage` number of versions |
568 | // returned. |
569 | // We create a mapping between original values and the associated loop |
570 | // returned values that will be needed by the epilogue. |
571 | llvm::SmallVector<Value> yieldOperands; |
572 | for (OpOperand &yieldOperand : |
573 | forOp.getBody()->getTerminator()->getOpOperands()) { |
574 | Value source = mapping.lookupOrDefault(yieldOperand.get()); |
575 | // When we don't peel the epilogue and the yield value is used outside the |
576 | // loop we need to make sure we return the version from numStages - |
577 | // defStage. |
578 | if (!peelEpilogue && |
579 | !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { |
580 | Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; |
581 | if (def) { |
582 | auto defStage = stages.find(def); |
583 | if (defStage != stages.end() && defStage->second < maxStage) { |
584 | Value pred = predicates[defStage->second]; |
585 | source = rewriter.create<arith::SelectOp>( |
586 | pred.getLoc(), pred, source, |
587 | newForOp.getBody() |
588 | ->getArguments()[yieldOperand.getOperandNumber() + 1]); |
589 | } |
590 | } |
591 | } |
592 | yieldOperands.push_back(source); |
593 | } |
594 | |
595 | for (auto &it : crossStageValues) { |
596 | int64_t version = maxStage - it.second.lastUseStage + 1; |
597 | unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; |
598 | // add the original version to yield ops. |
599 | // If there is a live range spanning across more than 2 stages we need to |
600 | // add extra arg. |
601 | for (unsigned i = 1; i < numVersionReturned; i++) { |
602 | setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()), |
603 | idx: version++); |
604 | yieldOperands.push_back( |
605 | Elt: newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + |
606 | newForOp.getNumInductionVars()]); |
607 | } |
608 | setValueMapping(key: it.first, el: newForOp->getResult(yieldOperands.size()), |
609 | idx: version++); |
610 | yieldOperands.push_back(Elt: mapping.lookupOrDefault(from: it.first)); |
611 | } |
612 | // Map the yield operand to the forOp returned value. |
613 | for (const auto &retVal : |
614 | llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { |
615 | Operation *def = retVal.value().getDefiningOp(); |
616 | assert(def && "Only support loop carried dependencies of distance of 1 or " |
617 | "defined outside the loop" ); |
618 | auto defStage = stages.find(def); |
619 | if (defStage == stages.end()) { |
620 | for (unsigned int stage = 1; stage <= maxStage; stage++) |
621 | setValueMapping(forOp.getRegionIterArgs()[retVal.index()], |
622 | retVal.value(), stage); |
623 | } else if (defStage->second > 0) { |
624 | setValueMapping(forOp.getRegionIterArgs()[retVal.index()], |
625 | newForOp->getResult(retVal.index()), |
626 | maxStage - defStage->second + 1); |
627 | } |
628 | } |
629 | rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands); |
630 | return success(); |
631 | } |
632 | |
633 | void LoopPipelinerInternal::emitEpilogue( |
634 | RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) { |
635 | // Emit different versions of the induction variable. They will be |
636 | // removed by dead code if not used. |
637 | for (int64_t i = 0; i < maxStage; i++) { |
638 | Location loc = forOp.getLoc(); |
639 | Type t = lb.getType(); |
640 | Value minusOne = |
641 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1)); |
642 | // number of iterations = ((ub - 1) - lb) / step |
643 | Value totalNumIteration = rewriter.create<arith::DivUIOp>( |
644 | loc, |
645 | rewriter.create<arith::SubIOp>( |
646 | loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb), |
647 | step); |
648 | // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) |
649 | Value minusI = |
650 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i)); |
651 | Value newlastIter = rewriter.create<arith::AddIOp>( |
652 | loc, lb, |
653 | rewriter.create<arith::MulIOp>( |
654 | loc, step, |
655 | rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI))); |
656 | setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); |
657 | } |
658 | // Emit `maxStage - 1` epilogue part that includes operations from stages |
659 | // [i; maxStage]. |
660 | for (int64_t i = 1; i <= maxStage; i++) { |
661 | for (Operation *op : opOrder) { |
662 | if (stages[op] < i) |
663 | continue; |
664 | Operation *newOp = |
665 | cloneAndUpdateOperands(rewriter, op, callback: [&](OpOperand *newOperand) { |
666 | auto it = valueMapping.find(Val: newOperand->get()); |
667 | if (it != valueMapping.end()) { |
668 | Value replacement = it->second[maxStage - stages[op] + i]; |
669 | newOperand->set(replacement); |
670 | } |
671 | }); |
672 | if (annotateFn) |
673 | annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); |
674 | for (unsigned destId : llvm::seq(Begin: unsigned(0), End: op->getNumResults())) { |
675 | setValueMapping(key: op->getResult(idx: destId), el: newOp->getResult(idx: destId), |
676 | idx: maxStage - stages[op] + i); |
677 | // If the value is a loop carried dependency update the loop argument |
678 | // mapping and keep track of the last version to replace the original |
679 | // forOp uses. |
680 | for (OpOperand &operand : |
681 | forOp.getBody()->getTerminator()->getOpOperands()) { |
682 | if (operand.get() != op->getResult(destId)) |
683 | continue; |
684 | unsigned version = maxStage - stages[op] + i + 1; |
685 | // If the version is greater than maxStage it means it maps to the |
686 | // original forOp returned value. |
687 | if (version > maxStage) { |
688 | returnValues[operand.getOperandNumber()] = newOp->getResult(destId); |
689 | continue; |
690 | } |
691 | setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], |
692 | newOp->getResult(destId), version); |
693 | } |
694 | } |
695 | } |
696 | } |
697 | } |
698 | |
699 | void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { |
700 | auto it = valueMapping.find(Val: key); |
701 | // If the value is not in the map yet add a vector big enough to store all |
702 | // versions. |
703 | if (it == valueMapping.end()) |
704 | it = |
705 | valueMapping |
706 | .insert(KV: std::make_pair(x&: key, y: llvm::SmallVector<Value>(maxStage + 1))) |
707 | .first; |
708 | it->second[idx] = el; |
709 | } |
710 | |
711 | } // namespace |
712 | |
713 | FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, |
714 | const PipeliningOption &options, |
715 | bool *modifiedIR) { |
716 | if (modifiedIR) |
717 | *modifiedIR = false; |
718 | LoopPipelinerInternal pipeliner; |
719 | if (!pipeliner.initializeLoopInfo(op: forOp, options)) |
720 | return failure(); |
721 | |
722 | if (modifiedIR) |
723 | *modifiedIR = true; |
724 | |
725 | // 1. Emit prologue. |
726 | pipeliner.emitPrologue(rewriter); |
727 | |
728 | // 2. Track values used across stages. When a value cross stages it will |
729 | // need to be passed as loop iteration arguments. |
730 | // We first collect the values that are used in a different stage than where |
731 | // they are defined. |
732 | llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> |
733 | crossStageValues = pipeliner.analyzeCrossStageValues(); |
734 | |
735 | // Mapping between original loop values used cross stage and the block |
736 | // arguments associated after pipelining. A Value may map to several |
737 | // arguments if its liverange spans across more than 2 stages. |
738 | llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap; |
739 | // 3. Create the new kernel loop and return the block arguments mapping. |
740 | ForOp newForOp = |
741 | pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); |
742 | // Create the kernel block, order ops based on user choice and remap |
743 | // operands. |
744 | if (failed(pipeliner.createKernel(newForOp: newForOp, crossStageValues, loopArgMap, |
745 | rewriter))) |
746 | return failure(); |
747 | |
748 | llvm::SmallVector<Value> returnValues = |
749 | newForOp.getResults().take_front(forOp->getNumResults()); |
750 | if (options.peelEpilogue) { |
751 | // 4. Emit the epilogue after the new forOp. |
752 | rewriter.setInsertionPointAfter(newForOp); |
753 | pipeliner.emitEpilogue(rewriter, returnValues); |
754 | } |
755 | // 5. Erase the original loop and replace the uses with the epilogue output. |
756 | if (forOp->getNumResults() > 0) |
757 | rewriter.replaceOp(forOp, returnValues); |
758 | else |
759 | rewriter.eraseOp(op: forOp); |
760 | |
761 | return newForOp; |
762 | } |
763 | |
764 | void mlir::scf::populateSCFLoopPipeliningPatterns( |
765 | RewritePatternSet &patterns, const PipeliningOption &options) { |
766 | patterns.add<ForLoopPipeliningPattern>(arg: options, args: patterns.getContext()); |
767 | } |
768 | |