//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//
12 | |
13 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
14 | #include "mlir/Analysis/SliceAnalysis.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/SCF/IR/SCF.h" |
19 | #include "mlir/IR/BuiltinOps.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
23 | #include "mlir/Support/MathExtras.h" |
24 | #include "mlir/Transforms/RegionUtils.h" |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "llvm/ADT/SetVector.h" |
27 | #include "llvm/ADT/SmallPtrSet.h" |
28 | #include "llvm/ADT/SmallVector.h" |
29 | |
30 | using namespace mlir; |
31 | |
32 | namespace { |
33 | // This structure is to pass and return sets of loop parameters without |
34 | // confusing the order. |
35 | struct LoopParams { |
36 | Value lowerBound; |
37 | Value upperBound; |
38 | Value step; |
39 | }; |
40 | } // namespace |
41 | |
42 | SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields( |
43 | RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest, |
44 | ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, |
45 | bool replaceIterOperandsUsesInLoop) { |
46 | if (loopNest.empty()) |
47 | return {}; |
48 | // This method is recursive (to make it more readable). Adding an |
49 | // assertion here to limit the recursion. (See |
50 | // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235) |
51 | assert(loopNest.size() <= 10 && |
52 | "exceeded recursion limit when yielding value from loop nest" ); |
53 | |
54 | // To yield a value from a perfectly nested loop nest, the following |
55 | // pattern needs to be created, i.e. starting with |
56 | // |
57 | // ```mlir |
58 | // scf.for .. { |
59 | // scf.for .. { |
60 | // scf.for .. { |
61 | // %value = ... |
62 | // } |
63 | // } |
64 | // } |
65 | // ``` |
66 | // |
67 | // needs to be modified to |
68 | // |
69 | // ```mlir |
70 | // %0 = scf.for .. iter_args(%arg0 = %init) { |
71 | // %1 = scf.for .. iter_args(%arg1 = %arg0) { |
72 | // %2 = scf.for .. iter_args(%arg2 = %arg1) { |
73 | // %value = ... |
74 | // scf.yield %value |
75 | // } |
76 | // scf.yield %2 |
77 | // } |
78 | // scf.yield %1 |
79 | // } |
80 | // ``` |
81 | // |
82 | // The inner most loop is handled using the `replaceWithAdditionalYields` |
83 | // that works on a single loop. |
84 | if (loopNest.size() == 1) { |
85 | auto innerMostLoop = |
86 | cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields( |
87 | rewriter, newIterOperands, replaceIterOperandsUsesInLoop, |
88 | newYieldValuesFn)); |
89 | return {innerMostLoop}; |
90 | } |
91 | // The outer loops are modified by calling this method recursively |
92 | // - The return value of the inner loop is the value yielded by this loop. |
93 | // - The region iter args of this loop are the init_args for the inner loop. |
94 | SmallVector<scf::ForOp> newLoopNest; |
95 | NewYieldValuesFn fn = |
96 | [&](OpBuilder &innerBuilder, Location loc, |
97 | ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> { |
98 | newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(), |
99 | innerNewBBArgs, newYieldValuesFn, |
100 | replaceIterOperandsUsesInLoop); |
101 | return llvm::to_vector(llvm::map_range( |
102 | newLoopNest.front().getResults().take_back(innerNewBBArgs.size()), |
103 | [](OpResult r) -> Value { return r; })); |
104 | }; |
105 | scf::ForOp outerMostLoop = |
106 | cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields( |
107 | rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn)); |
108 | newLoopNest.insert(newLoopNest.begin(), outerMostLoop); |
109 | return newLoopNest; |
110 | } |
111 | |
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
121 | FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter, |
122 | Location loc, |
123 | Region ®ion, |
124 | StringRef funcName, |
125 | func::CallOp *callOp) { |
126 | assert(!funcName.empty() && "funcName cannot be empty" ); |
127 | if (!region.hasOneBlock()) |
128 | return failure(); |
129 | |
130 | Block *originalBlock = ®ion.front(); |
131 | Operation *originalTerminator = originalBlock->getTerminator(); |
132 | |
133 | // Outline before current function. |
134 | OpBuilder::InsertionGuard g(rewriter); |
135 | rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>()); |
136 | |
137 | SetVector<Value> captures; |
138 | getUsedValuesDefinedAbove(regions: region, values&: captures); |
139 | |
140 | ValueRange outlinedValues(captures.getArrayRef()); |
141 | SmallVector<Type> outlinedFuncArgTypes; |
142 | SmallVector<Location> outlinedFuncArgLocs; |
143 | // Region's arguments are exactly the first block's arguments as per |
144 | // Region::getArguments(). |
145 | // Func's arguments are cat(regions's arguments, captures arguments). |
146 | for (BlockArgument arg : region.getArguments()) { |
147 | outlinedFuncArgTypes.push_back(Elt: arg.getType()); |
148 | outlinedFuncArgLocs.push_back(Elt: arg.getLoc()); |
149 | } |
150 | for (Value value : outlinedValues) { |
151 | outlinedFuncArgTypes.push_back(Elt: value.getType()); |
152 | outlinedFuncArgLocs.push_back(Elt: value.getLoc()); |
153 | } |
154 | FunctionType outlinedFuncType = |
155 | FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes, |
156 | originalTerminator->getOperandTypes()); |
157 | auto outlinedFunc = |
158 | rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType); |
159 | Block *outlinedFuncBody = outlinedFunc.addEntryBlock(); |
160 | |
161 | // Merge blocks while replacing the original block operands. |
162 | // Warning: `mergeBlocks` erases the original block, reconstruct it later. |
163 | int64_t numOriginalBlockArguments = originalBlock->getNumArguments(); |
164 | auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments(); |
165 | { |
166 | OpBuilder::InsertionGuard g(rewriter); |
167 | rewriter.setInsertionPointToEnd(outlinedFuncBody); |
168 | rewriter.mergeBlocks( |
169 | source: originalBlock, dest: outlinedFuncBody, |
170 | argValues: outlinedFuncBlockArgs.take_front(numOriginalBlockArguments)); |
171 | // Explicitly set up a new ReturnOp terminator. |
172 | rewriter.setInsertionPointToEnd(outlinedFuncBody); |
173 | rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(), |
174 | originalTerminator->getOperands()); |
175 | } |
176 | |
177 | // Reconstruct the block that was deleted and add a |
178 | // terminator(call_results). |
179 | Block *newBlock = rewriter.createBlock( |
180 | parent: ®ion, insertPt: region.begin(), |
181 | argTypes: TypeRange{outlinedFuncArgTypes}.take_front(n: numOriginalBlockArguments), |
182 | locs: ArrayRef<Location>(outlinedFuncArgLocs) |
183 | .take_front(N: numOriginalBlockArguments)); |
184 | { |
185 | OpBuilder::InsertionGuard g(rewriter); |
186 | rewriter.setInsertionPointToEnd(newBlock); |
187 | SmallVector<Value> callValues; |
188 | llvm::append_range(C&: callValues, R: newBlock->getArguments()); |
189 | llvm::append_range(C&: callValues, R&: outlinedValues); |
190 | auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues); |
191 | if (callOp) |
192 | *callOp = call; |
193 | |
194 | // `originalTerminator` was moved to `outlinedFuncBody` and is still valid. |
195 | // Clone `originalTerminator` to take the callOp results then erase it from |
196 | // `outlinedFuncBody`. |
197 | IRMapping bvm; |
198 | bvm.map(originalTerminator->getOperands(), call->getResults()); |
199 | rewriter.clone(op&: *originalTerminator, mapper&: bvm); |
200 | rewriter.eraseOp(op: originalTerminator); |
201 | } |
202 | |
203 | // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`. |
204 | // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`. |
205 | for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back( |
206 | outlinedValues.size()))) { |
207 | Value orig = std::get<0>(it); |
208 | Value repl = std::get<1>(it); |
209 | { |
210 | OpBuilder::InsertionGuard g(rewriter); |
211 | rewriter.setInsertionPointToStart(outlinedFuncBody); |
212 | if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) { |
213 | IRMapping bvm; |
214 | repl = rewriter.clone(*cst, bvm)->getResult(0); |
215 | } |
216 | } |
217 | orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) { |
218 | return outlinedFunc->isProperAncestor(opOperand.getOwner()); |
219 | }); |
220 | } |
221 | |
222 | return outlinedFunc; |
223 | } |
224 | |
225 | LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp, |
226 | func::FuncOp *thenFn, StringRef thenFnName, |
227 | func::FuncOp *elseFn, StringRef elseFnName) { |
228 | IRRewriter rewriter(b); |
229 | Location loc = ifOp.getLoc(); |
230 | FailureOr<func::FuncOp> outlinedFuncOpOrFailure; |
231 | if (thenFn && !ifOp.getThenRegion().empty()) { |
232 | outlinedFuncOpOrFailure = outlineSingleBlockRegion( |
233 | rewriter, loc, ifOp.getThenRegion(), thenFnName); |
234 | if (failed(result: outlinedFuncOpOrFailure)) |
235 | return failure(); |
236 | *thenFn = *outlinedFuncOpOrFailure; |
237 | } |
238 | if (elseFn && !ifOp.getElseRegion().empty()) { |
239 | outlinedFuncOpOrFailure = outlineSingleBlockRegion( |
240 | rewriter, loc, ifOp.getElseRegion(), elseFnName); |
241 | if (failed(result: outlinedFuncOpOrFailure)) |
242 | return failure(); |
243 | *elseFn = *outlinedFuncOpOrFailure; |
244 | } |
245 | return success(); |
246 | } |
247 | |
248 | bool mlir::getInnermostParallelLoops(Operation *rootOp, |
249 | SmallVectorImpl<scf::ParallelOp> &result) { |
250 | assert(rootOp != nullptr && "Root operation must not be a nullptr." ); |
251 | bool rootEnclosesPloops = false; |
252 | for (Region ®ion : rootOp->getRegions()) { |
253 | for (Block &block : region.getBlocks()) { |
254 | for (Operation &op : block) { |
255 | bool enclosesPloops = getInnermostParallelLoops(&op, result); |
256 | rootEnclosesPloops |= enclosesPloops; |
257 | if (auto ploop = dyn_cast<scf::ParallelOp>(op)) { |
258 | rootEnclosesPloops = true; |
259 | |
260 | // Collect parallel loop if it is an innermost one. |
261 | if (!enclosesPloops) |
262 | result.push_back(ploop); |
263 | } |
264 | } |
265 | } |
266 | } |
267 | return rootEnclosesPloops; |
268 | } |
269 | |
270 | // Build the IR that performs ceil division of a positive value by a constant: |
271 | // ceildiv(a, B) = divis(a + (B-1), B) |
272 | // where divis is rounding-to-zero division. |
273 | static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, |
274 | int64_t divisor) { |
275 | assert(divisor > 0 && "expected positive divisor" ); |
276 | assert(dividend.getType().isIndex() && "expected index-typed value" ); |
277 | |
278 | Value divisorMinusOneCst = |
279 | builder.create<arith::ConstantIndexOp>(location: loc, args: divisor - 1); |
280 | Value divisorCst = builder.create<arith::ConstantIndexOp>(location: loc, args&: divisor); |
281 | Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst); |
282 | return builder.create<arith::DivUIOp>(loc, sum, divisorCst); |
283 | } |
284 | |
285 | // Build the IR that performs ceil division of a positive value by another |
286 | // positive value: |
287 | // ceildiv(a, b) = divis(a + (b - 1), b) |
288 | // where divis is rounding-to-zero division. |
289 | static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, |
290 | Value divisor) { |
291 | assert(dividend.getType().isIndex() && "expected index-typed value" ); |
292 | |
293 | Value cstOne = builder.create<arith::ConstantIndexOp>(location: loc, args: 1); |
294 | Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne); |
295 | Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne); |
296 | return builder.create<arith::DivUIOp>(loc, sum, divisor); |
297 | } |
298 | |
299 | /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with |
300 | /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap |
301 | /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each |
302 | /// unrolled iteration using annotateFn. |
303 | static void generateUnrolledLoop( |
304 | Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, |
305 | function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn, |
306 | function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn, |
307 | ValueRange iterArgs, ValueRange yieldedValues) { |
308 | // Builder to insert unrolled bodies just before the terminator of the body of |
309 | // 'forOp'. |
310 | auto builder = OpBuilder::atBlockTerminator(block: loopBodyBlock); |
311 | |
312 | if (!annotateFn) |
313 | annotateFn = [](unsigned, Operation *, OpBuilder) {}; |
314 | |
315 | // Keep a pointer to the last non-terminator operation in the original block |
316 | // so that we know what to clone (since we are doing this in-place). |
317 | Block::iterator srcBlockEnd = std::prev(x: loopBodyBlock->end(), n: 2); |
318 | |
319 | // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies). |
320 | SmallVector<Value, 4> lastYielded(yieldedValues); |
321 | |
322 | for (unsigned i = 1; i < unrollFactor; i++) { |
323 | IRMapping operandMap; |
324 | |
325 | // Prepare operand map. |
326 | operandMap.map(from&: iterArgs, to&: lastYielded); |
327 | |
328 | // If the induction variable is used, create a remapping to the value for |
329 | // this unrolled instance. |
330 | if (!forOpIV.use_empty()) { |
331 | Value ivUnroll = ivRemapFn(i, forOpIV, builder); |
332 | operandMap.map(from: forOpIV, to: ivUnroll); |
333 | } |
334 | |
335 | // Clone the original body of 'forOp'. |
336 | for (auto it = loopBodyBlock->begin(); it != std::next(x: srcBlockEnd); it++) { |
337 | Operation *clonedOp = builder.clone(op&: *it, mapper&: operandMap); |
338 | annotateFn(i, clonedOp, builder); |
339 | } |
340 | |
341 | // Update yielded values. |
342 | for (unsigned i = 0, e = lastYielded.size(); i < e; i++) |
343 | lastYielded[i] = operandMap.lookup(from: yieldedValues[i]); |
344 | } |
345 | |
346 | // Make sure we annotate the Ops in the original body. We do this last so that |
347 | // any annotations are not copied into the cloned Ops above. |
348 | for (auto it = loopBodyBlock->begin(); it != std::next(x: srcBlockEnd); it++) |
349 | annotateFn(0, &*it, builder); |
350 | |
351 | // Update operands of the yield statement. |
352 | loopBodyBlock->getTerminator()->setOperands(lastYielded); |
353 | } |
354 | |
355 | /// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled. |
356 | LogicalResult mlir::loopUnrollByFactor( |
357 | scf::ForOp forOp, uint64_t unrollFactor, |
358 | function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) { |
359 | assert(unrollFactor > 0 && "expected positive unroll factor" ); |
360 | |
361 | // Return if the loop body is empty. |
362 | if (llvm::hasSingleElement(forOp.getBody()->getOperations())) |
363 | return success(); |
364 | |
365 | // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate |
366 | // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. |
367 | OpBuilder boundsBuilder(forOp); |
368 | IRRewriter rewriter(forOp.getContext()); |
369 | auto loc = forOp.getLoc(); |
370 | Value step = forOp.getStep(); |
371 | Value upperBoundUnrolled; |
372 | Value stepUnrolled; |
373 | bool generateEpilogueLoop = true; |
374 | |
375 | std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound()); |
376 | std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound()); |
377 | std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep()); |
378 | if (lbCstOp && ubCstOp && stepCstOp) { |
379 | // Constant loop bounds computation. |
380 | int64_t lbCst = lbCstOp.value(); |
381 | int64_t ubCst = ubCstOp.value(); |
382 | int64_t stepCst = stepCstOp.value(); |
383 | assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 && |
384 | "expected positive loop bounds and step" ); |
385 | int64_t tripCount = mlir::ceilDiv(lhs: ubCst - lbCst, rhs: stepCst); |
386 | |
387 | if (unrollFactor == 1) { |
388 | if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) |
389 | return failure(); |
390 | return success(); |
391 | } |
392 | |
393 | int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor); |
394 | int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; |
395 | int64_t stepUnrolledCst = stepCst * unrollFactor; |
396 | |
397 | // Create constant for 'upperBoundUnrolled' and set epilogue loop flag. |
398 | generateEpilogueLoop = upperBoundUnrolledCst < ubCst; |
399 | if (generateEpilogueLoop) |
400 | upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>( |
401 | loc, upperBoundUnrolledCst); |
402 | else |
403 | upperBoundUnrolled = forOp.getUpperBound(); |
404 | |
405 | // Create constant for 'stepUnrolled'. |
406 | stepUnrolled = stepCst == stepUnrolledCst |
407 | ? step |
408 | : boundsBuilder.create<arith::ConstantIndexOp>( |
409 | loc, stepUnrolledCst); |
410 | } else { |
411 | // Dynamic loop bounds computation. |
412 | // TODO: Add dynamic asserts for negative lb/ub/step, or |
413 | // consider using ceilDiv from AffineApplyExpander. |
414 | auto lowerBound = forOp.getLowerBound(); |
415 | auto upperBound = forOp.getUpperBound(); |
416 | Value diff = |
417 | boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound); |
418 | Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step); |
419 | Value unrollFactorCst = |
420 | boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor); |
421 | Value tripCountRem = |
422 | boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst); |
423 | // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor) |
424 | Value tripCountEvenMultiple = |
425 | boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem); |
426 | // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step |
427 | upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>( |
428 | loc, lowerBound, |
429 | boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step)); |
430 | // Scale 'step' by 'unrollFactor'. |
431 | stepUnrolled = |
432 | boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst); |
433 | } |
434 | |
435 | // Create epilogue clean up loop starting at 'upperBoundUnrolled'. |
436 | if (generateEpilogueLoop) { |
437 | OpBuilder epilogueBuilder(forOp->getContext()); |
438 | epilogueBuilder.setInsertionPoint(forOp->getBlock(), |
439 | std::next(x: Block::iterator(forOp))); |
440 | auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp)); |
441 | epilogueForOp.setLowerBound(upperBoundUnrolled); |
442 | |
443 | // Update uses of loop results. |
444 | auto results = forOp.getResults(); |
445 | auto epilogueResults = epilogueForOp.getResults(); |
446 | |
447 | for (auto e : llvm::zip(results, epilogueResults)) { |
448 | std::get<0>(e).replaceAllUsesWith(std::get<1>(e)); |
449 | } |
450 | epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), |
451 | epilogueForOp.getInitArgs().size(), results); |
452 | (void)epilogueForOp.promoteIfSingleIteration(rewriter); |
453 | } |
454 | |
455 | // Create unrolled loop. |
456 | forOp.setUpperBound(upperBoundUnrolled); |
457 | forOp.setStep(stepUnrolled); |
458 | |
459 | auto iterArgs = ValueRange(forOp.getRegionIterArgs()); |
460 | auto yieldedValues = forOp.getBody()->getTerminator()->getOperands(); |
461 | |
462 | generateUnrolledLoop( |
463 | forOp.getBody(), forOp.getInductionVar(), unrollFactor, |
464 | [&](unsigned i, Value iv, OpBuilder b) { |
465 | // iv' = iv + step * i; |
466 | auto stride = b.create<arith::MulIOp>( |
467 | loc, step, b.create<arith::ConstantIndexOp>(loc, i)); |
468 | return b.create<arith::AddIOp>(loc, iv, stride); |
469 | }, |
470 | annotateFn, iterArgs, yieldedValues); |
471 | // Promote the loop body up if this has turned into a single iteration loop. |
472 | (void)forOp.promoteIfSingleIteration(rewriter); |
473 | return success(); |
474 | } |
475 | |
476 | /// Transform a loop with a strictly positive step |
477 | /// for %i = %lb to %ub step %s |
478 | /// into a 0-based loop with step 1 |
479 | /// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 { |
480 | /// %i = %ii * %s + %lb |
481 | /// Insert the induction variable remapping in the body of `inner`, which is |
482 | /// expected to be either `loop` or another loop perfectly nested under `loop`. |
483 | /// Insert the definition of new bounds immediate before `outer`, which is |
484 | /// expected to be either `loop` or its parent in the loop nest. |
485 | static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, |
486 | Value lb, Value ub, Value step) { |
487 | // For non-index types, generate `arith` instructions |
488 | // Check if the loop is already known to have a constant zero lower bound or |
489 | // a constant one step. |
490 | bool isZeroBased = false; |
491 | if (auto lbCst = getConstantIntValue(ofr: lb)) |
492 | isZeroBased = lbCst.value() == 0; |
493 | |
494 | bool isStepOne = false; |
495 | if (auto stepCst = getConstantIntValue(ofr: step)) |
496 | isStepOne = stepCst.value() == 1; |
497 | |
498 | // Compute the number of iterations the loop executes: ceildiv(ub - lb, step) |
499 | // assuming the step is strictly positive. Update the bounds and the step |
500 | // of the loop to go from 0 to the number of iterations, if necessary. |
501 | if (isZeroBased && isStepOne) |
502 | return {.lowerBound: lb, .upperBound: ub, .step: step}; |
503 | |
504 | Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb); |
505 | Value newUpperBound = |
506 | isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step); |
507 | |
508 | Value newLowerBound = isZeroBased |
509 | ? lb |
510 | : rewriter.create<arith::ConstantOp>( |
511 | loc, rewriter.getZeroAttr(lb.getType())); |
512 | Value newStep = isStepOne |
513 | ? step |
514 | : rewriter.create<arith::ConstantOp>( |
515 | loc, rewriter.getIntegerAttr(step.getType(), 1)); |
516 | |
517 | return {.lowerBound: newLowerBound, .upperBound: newUpperBound, .step: newStep}; |
518 | } |
519 | |
520 | /// Get back the original induction variable values after loop normalization |
521 | static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, |
522 | Value normalizedIv, Value origLb, |
523 | Value origStep) { |
524 | Value denormalizedIv; |
525 | SmallPtrSet<Operation *, 2> preserve; |
526 | bool isStepOne = isConstantIntValue(ofr: origStep, value: 1); |
527 | bool isZeroBased = isConstantIntValue(ofr: origLb, value: 0); |
528 | |
529 | Value scaled = normalizedIv; |
530 | if (!isStepOne) { |
531 | scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep); |
532 | preserve.insert(Ptr: scaled.getDefiningOp()); |
533 | } |
534 | denormalizedIv = scaled; |
535 | if (!isZeroBased) { |
536 | denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb); |
537 | preserve.insert(Ptr: denormalizedIv.getDefiningOp()); |
538 | } |
539 | |
540 | rewriter.replaceAllUsesExcept(from: normalizedIv, to: denormalizedIv, preservedUsers: preserve); |
541 | } |
542 | |
543 | /// Helper function to multiply a sequence of values. |
544 | static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, |
545 | ArrayRef<Value> values) { |
546 | assert(!values.empty() && "unexpected empty list" ); |
547 | Value productOf = values.front(); |
548 | for (auto v : values.drop_front()) { |
549 | productOf = rewriter.create<arith::MulIOp>(loc, productOf, v); |
550 | } |
551 | return productOf; |
552 | } |
553 | |
554 | /// For each original loop, the value of the |
555 | /// induction variable can be obtained by dividing the induction variable of |
556 | /// the linearized loop by the total number of iterations of the loops nested |
557 | /// in it modulo the number of iterations in this loop (remove the values |
558 | /// related to the outer loops): |
559 | /// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i. |
560 | /// Compute these iteratively from the innermost loop by creating a "running |
561 | /// quotient" of division by the range. |
562 | static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>> |
563 | delinearizeInductionVariable(RewriterBase &rewriter, Location loc, |
564 | Value linearizedIv, ArrayRef<Value> ubs) { |
565 | Value previous = linearizedIv; |
566 | SmallVector<Value> delinearizedIvs(ubs.size()); |
567 | SmallPtrSet<Operation *, 2> preservedUsers; |
568 | for (unsigned i = 0, e = ubs.size(); i < e; ++i) { |
569 | unsigned idx = ubs.size() - i - 1; |
570 | if (i != 0) { |
571 | previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]); |
572 | preservedUsers.insert(Ptr: previous.getDefiningOp()); |
573 | } |
574 | Value iv = previous; |
575 | if (i != e - 1) { |
576 | iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]); |
577 | preservedUsers.insert(Ptr: iv.getDefiningOp()); |
578 | } |
579 | delinearizedIvs[idx] = iv; |
580 | } |
581 | return {delinearizedIvs, preservedUsers}; |
582 | } |
583 | |
584 | LogicalResult mlir::coalesceLoops(RewriterBase &rewriter, |
585 | MutableArrayRef<scf::ForOp> loops) { |
586 | if (loops.size() < 2) |
587 | return failure(); |
588 | |
589 | scf::ForOp innermost = loops.back(); |
590 | scf::ForOp outermost = loops.front(); |
591 | |
592 | // 1. Make sure all loops iterate from 0 to upperBound with step 1. This |
593 | // allows the following code to assume upperBound is the number of iterations. |
594 | for (auto loop : loops) { |
595 | OpBuilder::InsertionGuard g(rewriter); |
596 | rewriter.setInsertionPoint(outermost); |
597 | Value lb = loop.getLowerBound(); |
598 | Value ub = loop.getUpperBound(); |
599 | Value step = loop.getStep(); |
600 | auto newLoopParams = |
601 | emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step); |
602 | |
603 | rewriter.modifyOpInPlace(loop, [&]() { |
604 | loop.setLowerBound(newLoopParams.lowerBound); |
605 | loop.setUpperBound(newLoopParams.upperBound); |
606 | loop.setStep(newLoopParams.step); |
607 | }); |
608 | |
609 | rewriter.setInsertionPointToStart(innermost.getBody()); |
610 | denormalizeInductionVariable(rewriter, loop.getLoc(), |
611 | loop.getInductionVar(), lb, step); |
612 | } |
613 | |
614 | // 2. Emit code computing the upper bound of the coalesced loop as product |
615 | // of the number of iterations of all loops. |
616 | OpBuilder::InsertionGuard g(rewriter); |
617 | rewriter.setInsertionPoint(outermost); |
618 | Location loc = outermost.getLoc(); |
619 | SmallVector<Value> upperBounds = llvm::map_to_vector( |
620 | loops, [](auto loop) { return loop.getUpperBound(); }); |
621 | Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, values: upperBounds); |
622 | outermost.setUpperBound(upperBound); |
623 | |
624 | rewriter.setInsertionPointToStart(innermost.getBody()); |
625 | auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable( |
626 | rewriter, loc, outermost.getInductionVar(), upperBounds); |
627 | rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0], |
628 | preservedUsers); |
629 | |
630 | for (int i = loops.size() - 1; i > 0; --i) { |
631 | auto outerLoop = loops[i - 1]; |
632 | auto innerLoop = loops[i]; |
633 | |
634 | Operation *innerTerminator = innerLoop.getBody()->getTerminator(); |
635 | auto yieldedVals = llvm::to_vector(Range: innerTerminator->getOperands()); |
636 | rewriter.eraseOp(op: innerTerminator); |
637 | |
638 | SmallVector<Value> innerBlockArgs; |
639 | innerBlockArgs.push_back(Elt: delinearizeIvs[i]); |
640 | llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs()); |
641 | rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(), |
642 | Block::iterator(innerLoop), innerBlockArgs); |
643 | rewriter.replaceOp(innerLoop, yieldedVals); |
644 | } |
645 | return success(); |
646 | } |
647 | |
648 | LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) { |
649 | if (loops.empty()) { |
650 | return failure(); |
651 | } |
652 | IRRewriter rewriter(loops.front().getContext()); |
653 | return coalesceLoops(rewriter, loops); |
654 | } |
655 | |
656 | LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) { |
657 | LogicalResult result(failure()); |
658 | SmallVector<scf::ForOp> loops; |
659 | getPerfectlyNestedLoops(loops, op); |
660 | |
661 | // Look for a band of loops that can be coalesced, i.e. perfectly nested |
662 | // loops with bounds defined above some loop. |
663 | |
664 | // 1. For each loop, find above which parent loop its bounds operands are |
665 | // defined. |
666 | SmallVector<unsigned> operandsDefinedAbove(loops.size()); |
667 | for (unsigned i = 0, e = loops.size(); i < e; ++i) { |
668 | operandsDefinedAbove[i] = i; |
669 | for (unsigned j = 0; j < i; ++j) { |
670 | SmallVector<Value> boundsOperands = {loops[i].getLowerBound(), |
671 | loops[i].getUpperBound(), |
672 | loops[i].getStep()}; |
673 | if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) { |
674 | operandsDefinedAbove[i] = j; |
675 | break; |
676 | } |
677 | } |
678 | } |
679 | |
680 | // 2. For each inner loop check that the iter_args for the immediately outer |
681 | // loop are the init for the immediately inner loop and that the yields of the |
682 | // return of the inner loop is the yield for the immediately outer loop. Keep |
683 | // track of where the chain starts from for each loop. |
684 | SmallVector<unsigned> iterArgChainStart(loops.size()); |
685 | iterArgChainStart[0] = 0; |
686 | for (unsigned i = 1, e = loops.size(); i < e; ++i) { |
687 | // By default set the start of the chain to itself. |
688 | iterArgChainStart[i] = i; |
689 | auto outerloop = loops[i - 1]; |
690 | auto innerLoop = loops[i]; |
691 | if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) { |
692 | continue; |
693 | } |
694 | if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) { |
695 | continue; |
696 | } |
697 | auto outerloopTerminator = outerloop.getBody()->getTerminator(); |
698 | if (!llvm::equal(outerloopTerminator->getOperands(), |
699 | innerLoop.getResults())) { |
700 | continue; |
701 | } |
702 | iterArgChainStart[i] = iterArgChainStart[i - 1]; |
703 | } |
704 | |
705 | // 3. Identify bands of loops such that the operands of all of them are |
706 | // defined above the first loop in the band. Traverse the nest bottom-up |
707 | // so that modifications don't invalidate the inner loops. |
708 | for (unsigned end = loops.size(); end > 0; --end) { |
709 | unsigned start = 0; |
710 | for (; start < end - 1; ++start) { |
711 | auto maxPos = |
712 | *std::max_element(first: std::next(x: operandsDefinedAbove.begin(), n: start), |
713 | last: std::next(x: operandsDefinedAbove.begin(), n: end)); |
714 | if (maxPos > start) |
715 | continue; |
716 | if (iterArgChainStart[end - 1] > start) |
717 | continue; |
718 | auto band = llvm::MutableArrayRef(loops.data() + start, end - start); |
719 | if (succeeded(coalesceLoops(band))) |
720 | result = success(); |
721 | break; |
722 | } |
723 | // If a band was found and transformed, keep looking at the loops above |
724 | // the outermost transformed loop. |
725 | if (start != end - 1) |
726 | end = start + 1; |
727 | } |
728 | return result; |
729 | } |
730 | |
731 | void mlir::collapseParallelLoops( |
732 | RewriterBase &rewriter, scf::ParallelOp loops, |
733 | ArrayRef<std::vector<unsigned>> combinedDimensions) { |
734 | OpBuilder::InsertionGuard g(rewriter); |
735 | rewriter.setInsertionPoint(loops); |
736 | Location loc = loops.getLoc(); |
737 | |
738 | // Presort combined dimensions. |
739 | auto sortedDimensions = llvm::to_vector<3>(Range&: combinedDimensions); |
740 | for (auto &dims : sortedDimensions) |
741 | llvm::sort(C&: dims); |
742 | |
743 | // Normalize ParallelOp's iteration pattern. |
744 | SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps, |
745 | normalizedUpperBounds; |
746 | for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) { |
747 | OpBuilder::InsertionGuard g2(rewriter); |
748 | rewriter.setInsertionPoint(loops); |
749 | Value lb = loops.getLowerBound()[i]; |
750 | Value ub = loops.getUpperBound()[i]; |
751 | Value step = loops.getStep()[i]; |
752 | auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step); |
753 | normalizedLowerBounds.push_back(Elt: newLoopParams.lowerBound); |
754 | normalizedUpperBounds.push_back(Elt: newLoopParams.upperBound); |
755 | normalizedSteps.push_back(Elt: newLoopParams.step); |
756 | |
757 | rewriter.setInsertionPointToStart(loops.getBody()); |
758 | denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb, |
759 | step); |
760 | } |
761 | |
762 | // Combine iteration spaces. |
763 | SmallVector<Value, 3> lowerBounds, upperBounds, steps; |
764 | auto cst0 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
765 | auto cst1 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
766 | for (auto &sortedDimension : sortedDimensions) { |
767 | Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
768 | for (auto idx : sortedDimension) { |
769 | newUpperBound = rewriter.create<arith::MulIOp>( |
770 | loc, newUpperBound, normalizedUpperBounds[idx]); |
771 | } |
772 | lowerBounds.push_back(Elt: cst0); |
773 | steps.push_back(Elt: cst1); |
774 | upperBounds.push_back(Elt: newUpperBound); |
775 | } |
776 | |
777 | // Create new ParallelLoop with conversions to the original induction values. |
778 | // The loop below uses divisions to get the relevant range of values in the |
779 | // new induction value that represent each range of the original induction |
780 | // value. The remainders then determine based on that range, which iteration |
781 | // of the original induction value this represents. This is a normalized value |
782 | // that is un-normalized already by the previous logic. |
783 | auto newPloop = rewriter.create<scf::ParallelOp>( |
784 | loc, lowerBounds, upperBounds, steps, |
785 | [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) { |
786 | for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) { |
787 | Value previous = ploopIVs[i]; |
788 | unsigned numberCombinedDimensions = combinedDimensions[i].size(); |
789 | // Iterate over all except the last induction value. |
790 | for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) { |
791 | unsigned idx = combinedDimensions[i][j]; |
792 | |
793 | // Determine the current induction value's current loop iteration |
794 | Value iv = insideBuilder.create<arith::RemSIOp>( |
795 | loc, previous, normalizedUpperBounds[idx]); |
796 | replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv, |
797 | loops.getRegion()); |
798 | |
799 | // Remove the effect of the current induction value to prepare for |
800 | // the next value. |
801 | previous = insideBuilder.create<arith::DivSIOp>( |
802 | loc, previous, normalizedUpperBounds[idx]); |
803 | } |
804 | |
805 | // The final induction value is just the remaining value. |
806 | unsigned idx = combinedDimensions[i][0]; |
807 | replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), |
808 | previous, loops.getRegion()); |
809 | } |
810 | }); |
811 | |
812 | // Replace the old loop with the new loop. |
813 | loops.getBody()->back().erase(); |
814 | newPloop.getBody()->getOperations().splice( |
815 | Block::iterator(newPloop.getBody()->back()), |
816 | loops.getBody()->getOperations()); |
817 | loops.erase(); |
818 | } |
819 | |
820 | // Hoist the ops within `outer` that appear before `inner`. |
821 | // Such ops include the ops that have been introduced by parametric tiling. |
822 | // Ops that come from triangular loops (i.e. that belong to the program slice |
823 | // rooted at `outer`) and ops that have side effects cannot be hoisted. |
824 | // Return failure when any op fails to hoist. |
825 | static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) { |
826 | SetVector<Operation *> forwardSlice; |
827 | ForwardSliceOptions options; |
828 | options.filter = [&inner](Operation *op) { |
829 | return op != inner.getOperation(); |
830 | }; |
831 | getForwardSlice(outer.getInductionVar(), &forwardSlice, options); |
832 | LogicalResult status = success(); |
833 | SmallVector<Operation *, 8> toHoist; |
834 | for (auto &op : outer.getBody()->without_terminator()) { |
835 | // Stop when encountering the inner loop. |
836 | if (&op == inner.getOperation()) |
837 | break; |
838 | // Skip over non-hoistable ops. |
839 | if (forwardSlice.count(&op) > 0) { |
840 | status = failure(); |
841 | continue; |
842 | } |
843 | // Skip intermediate scf::ForOp, these are not considered a failure. |
844 | if (isa<scf::ForOp>(op)) |
845 | continue; |
846 | // Skip other ops with regions. |
847 | if (op.getNumRegions() > 0) { |
848 | status = failure(); |
849 | continue; |
850 | } |
851 | // Skip if op has side effects. |
852 | // TODO: loads to immutable memory regions are ok. |
853 | if (!isMemoryEffectFree(&op)) { |
854 | status = failure(); |
855 | continue; |
856 | } |
857 | toHoist.push_back(&op); |
858 | } |
859 | auto *outerForOp = outer.getOperation(); |
860 | for (auto *op : toHoist) |
861 | op->moveBefore(outerForOp); |
862 | return status; |
863 | } |
864 | |
865 | // Traverse the interTile and intraTile loops and try to hoist ops such that |
866 | // bands of perfectly nested loops are isolated. |
867 | // Return failure if either perfect interTile or perfect intraTile bands cannot |
868 | // be formed. |
869 | static LogicalResult tryIsolateBands(const TileLoops &tileLoops) { |
870 | LogicalResult status = success(); |
871 | const Loops &interTile = tileLoops.first; |
872 | const Loops &intraTile = tileLoops.second; |
873 | auto size = interTile.size(); |
874 | assert(size == intraTile.size()); |
875 | if (size <= 1) |
876 | return success(); |
877 | for (unsigned s = 1; s < size; ++s) |
878 | status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s]) |
879 | : failure(); |
880 | for (unsigned s = 1; s < size; ++s) |
881 | status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s]) |
882 | : failure(); |
883 | return status; |
884 | } |
885 | |
886 | /// Collect perfectly nested loops starting from `rootForOps`. Loops are |
887 | /// perfectly nested if each loop is the first and only non-terminator operation |
888 | /// in the parent loop. Collect at most `maxLoops` loops and append them to |
889 | /// `forOps`. |
890 | template <typename T> |
891 | static void getPerfectlyNestedLoopsImpl( |
892 | SmallVectorImpl<T> &forOps, T rootForOp, |
893 | unsigned maxLoops = std::numeric_limits<unsigned>::max()) { |
894 | for (unsigned i = 0; i < maxLoops; ++i) { |
895 | forOps.push_back(rootForOp); |
896 | Block &body = rootForOp.getRegion().front(); |
897 | if (body.begin() != std::prev(x: body.end(), n: 2)) |
898 | return; |
899 | |
900 | rootForOp = dyn_cast<T>(&body.front()); |
901 | if (!rootForOp) |
902 | return; |
903 | } |
904 | } |
905 | |
/// Strip-mine `forOp` by `factor`: multiply its step by `factor`, then create
/// inside each loop in `targets` a new inner loop that iterates over the
/// original step range, moving each target's body into the new inner loop.
/// Returns the newly created inner loops, one per target.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  // The outer loop now advances by originalStep * factor.
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    // Inner upper bound is min(outer ub, iv + new step) so the last,
    // possibly partial, tile does not overrun the outer bound.
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value ub =
        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    // Uses of the outer IV inside the moved body are rewired to the inner IV.
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
939 | |
940 | // Stripmines a `forOp` by `factor` and sinks it under a single `target`. |
941 | // Returns the new for operation, nested immediately under `target`. |
942 | template <typename SizeType> |
943 | static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor, |
944 | scf::ForOp target) { |
945 | // TODO: Use cheap structural assertions that targets are nested under |
946 | // forOp and that targets are not nested under each other when DominanceInfo |
947 | // exposes the capability. It seems overkill to construct a whole function |
948 | // dominance tree at this point. |
949 | auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target)); |
950 | assert(res.size() == 1 && "Expected 1 inner forOp" ); |
951 | return res[0]; |
952 | } |
953 | |
954 | SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps, |
955 | ArrayRef<Value> sizes, |
956 | ArrayRef<scf::ForOp> targets) { |
957 | SmallVector<SmallVector<scf::ForOp, 8>, 8> res; |
958 | SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end()); |
959 | for (auto it : llvm::zip(forOps, sizes)) { |
960 | auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets); |
961 | res.push_back(step); |
962 | currentTargets = step; |
963 | } |
964 | return res; |
965 | } |
966 | |
967 | Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes, |
968 | scf::ForOp target) { |
969 | SmallVector<scf::ForOp, 8> res; |
970 | for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) { |
971 | assert(loops.size() == 1); |
972 | res.push_back(loops[0]); |
973 | } |
974 | return res; |
975 | } |
976 | |
977 | Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) { |
978 | // Collect perfectly nested loops. If more size values provided than nested |
979 | // loops available, truncate `sizes`. |
980 | SmallVector<scf::ForOp, 4> forOps; |
981 | forOps.reserve(sizes.size()); |
982 | getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size()); |
983 | if (forOps.size() < sizes.size()) |
984 | sizes = sizes.take_front(N: forOps.size()); |
985 | |
986 | return ::tile(forOps, sizes, forOps.back()); |
987 | } |
988 | |
/// Collect into `nestedLoops` the whole perfectly nested band rooted at
/// `root` (the root itself is always included), with no loop-count cap.
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}
993 | |
994 | TileLoops mlir::(scf::ForOp rootForOp, |
995 | ArrayRef<int64_t> sizes) { |
996 | // Collect perfectly nested loops. If more size values provided than nested |
997 | // loops available, truncate `sizes`. |
998 | SmallVector<scf::ForOp, 4> forOps; |
999 | forOps.reserve(sizes.size()); |
1000 | getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size()); |
1001 | if (forOps.size() < sizes.size()) |
1002 | sizes = sizes.take_front(N: forOps.size()); |
1003 | |
1004 | // Compute the tile sizes such that i-th outer loop executes size[i] |
1005 | // iterations. Given that the loop current executes |
1006 | // numIterations = ceildiv((upperBound - lowerBound), step) |
1007 | // iterations, we need to tile with size ceildiv(numIterations, size[i]). |
1008 | SmallVector<Value, 4> tileSizes; |
1009 | tileSizes.reserve(N: sizes.size()); |
1010 | for (unsigned i = 0, e = sizes.size(); i < e; ++i) { |
1011 | assert(sizes[i] > 0 && "expected strictly positive size for strip-mining" ); |
1012 | |
1013 | auto forOp = forOps[i]; |
1014 | OpBuilder builder(forOp); |
1015 | auto loc = forOp.getLoc(); |
1016 | Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(), |
1017 | forOp.getLowerBound()); |
1018 | Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep()); |
1019 | Value iterationsPerBlock = |
1020 | ceilDivPositive(builder, loc, numIterations, sizes[i]); |
1021 | tileSizes.push_back(Elt: iterationsPerBlock); |
1022 | } |
1023 | |
1024 | // Call parametric tiling with the given sizes. |
1025 | auto intraTile = tile(forOps, tileSizes, forOps.back()); |
1026 | TileLoops tileLoops = std::make_pair(forOps, intraTile); |
1027 | |
1028 | // TODO: for now we just ignore the result of band isolation. |
1029 | // In the future, mapping decisions may be impacted by the ability to |
1030 | // isolate perfectly nested bands. |
1031 | (void)tryIsolateBands(tileLoops); |
1032 | |
1033 | return tileLoops; |
1034 | } |
1035 | |
1036 | scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, |
1037 | scf::ForallOp source, |
1038 | RewriterBase &rewriter) { |
1039 | unsigned numTargetOuts = target.getNumResults(); |
1040 | unsigned numSourceOuts = source.getNumResults(); |
1041 | |
1042 | // Create fused shared_outs. |
1043 | SmallVector<Value> fusedOuts; |
1044 | llvm::append_range(fusedOuts, target.getOutputs()); |
1045 | llvm::append_range(fusedOuts, source.getOutputs()); |
1046 | |
1047 | // Create a new scf.forall op after the source loop. |
1048 | rewriter.setInsertionPointAfter(source); |
1049 | scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>( |
1050 | source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), |
1051 | source.getMixedStep(), fusedOuts, source.getMapping()); |
1052 | |
1053 | // Map control operands. |
1054 | IRMapping mapping; |
1055 | mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); |
1056 | mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); |
1057 | |
1058 | // Map shared outs. |
1059 | mapping.map(target.getRegionIterArgs(), |
1060 | fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); |
1061 | mapping.map(source.getRegionIterArgs(), |
1062 | fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); |
1063 | |
1064 | // Append everything except the terminator into the fused operation. |
1065 | rewriter.setInsertionPointToStart(fusedLoop.getBody()); |
1066 | for (Operation &op : target.getBody()->without_terminator()) |
1067 | rewriter.clone(op, mapping); |
1068 | for (Operation &op : source.getBody()->without_terminator()) |
1069 | rewriter.clone(op, mapping); |
1070 | |
1071 | // Fuse the old terminator in_parallel ops into the new one. |
1072 | scf::InParallelOp targetTerm = target.getTerminator(); |
1073 | scf::InParallelOp sourceTerm = source.getTerminator(); |
1074 | scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); |
1075 | rewriter.setInsertionPointToStart(fusedTerm.getBody()); |
1076 | for (Operation &op : targetTerm.getYieldingOps()) |
1077 | rewriter.clone(op, mapping); |
1078 | for (Operation &op : sourceTerm.getYieldingOps()) |
1079 | rewriter.clone(op, mapping); |
1080 | |
1081 | // Replace old loops by substituting their uses by results of the fused loop. |
1082 | rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); |
1083 | rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); |
1084 | |
1085 | return fusedLoop; |
1086 | } |
1087 | |
1088 | scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, |
1089 | scf::ForOp source, |
1090 | RewriterBase &rewriter) { |
1091 | unsigned numTargetOuts = target.getNumResults(); |
1092 | unsigned numSourceOuts = source.getNumResults(); |
1093 | |
1094 | // Create fused init_args, with target's init_args before source's init_args. |
1095 | SmallVector<Value> fusedInitArgs; |
1096 | llvm::append_range(fusedInitArgs, target.getInitArgs()); |
1097 | llvm::append_range(fusedInitArgs, source.getInitArgs()); |
1098 | |
1099 | // Create a new scf.for op after the source loop (with scf.yield terminator |
1100 | // (without arguments) only in case its init_args is empty). |
1101 | rewriter.setInsertionPointAfter(source); |
1102 | scf::ForOp fusedLoop = rewriter.create<scf::ForOp>( |
1103 | source.getLoc(), source.getLowerBound(), source.getUpperBound(), |
1104 | source.getStep(), fusedInitArgs); |
1105 | |
1106 | // Map original induction variables and operands to those of the fused loop. |
1107 | IRMapping mapping; |
1108 | mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); |
1109 | mapping.map(target.getRegionIterArgs(), |
1110 | fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); |
1111 | mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); |
1112 | mapping.map(source.getRegionIterArgs(), |
1113 | fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); |
1114 | |
1115 | // Merge target's body into the new (fused) for loop and then source's body. |
1116 | rewriter.setInsertionPointToStart(fusedLoop.getBody()); |
1117 | for (Operation &op : target.getBody()->without_terminator()) |
1118 | rewriter.clone(op, mapping); |
1119 | for (Operation &op : source.getBody()->without_terminator()) |
1120 | rewriter.clone(op, mapping); |
1121 | |
1122 | // Build fused yield results by appropriately mapping original yield operands. |
1123 | SmallVector<Value> yieldResults; |
1124 | for (Value operand : target.getBody()->getTerminator()->getOperands()) |
1125 | yieldResults.push_back(mapping.lookupOrDefault(operand)); |
1126 | for (Value operand : source.getBody()->getTerminator()->getOperands()) |
1127 | yieldResults.push_back(mapping.lookupOrDefault(operand)); |
1128 | if (!yieldResults.empty()) |
1129 | rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults); |
1130 | |
1131 | // Replace old loops by substituting their uses by results of the fused loop. |
1132 | rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); |
1133 | rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); |
1134 | |
1135 | return fusedLoop; |
1136 | } |
1137 | |