//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cstdint>

using namespace mlir;

#define DEBUG_TYPE "scf-utils"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

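// Hedged usage sketch (not part of the upstream docs): given a perfectly
// nested band `loopNest` collected by the caller, an extra value can be
// yielded from the innermost loop roughly as follows. `init` and
// `valueComputedInBody` are hypothetical values assumed to be in scope.
//
//   SmallVector<scf::ForOp> newNest = replaceLoopNestWithNewYields(
//       rewriter, loopNest, /*newIterOperands=*/ValueRange{init},
//       [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)
//           -> SmallVector<Value> { return {valueComputedInBody}; },
//       /*replaceIterOperandsUsesInLoop=*/true);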
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, the following
  // pattern needs to be created, i.e. starting with
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  //  %0 = scf.for .. iter_args(%arg0 = %init) {
  //    %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //      %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //        %value = ...
  //        scf.yield %value
  //      }
  //      scf.yield %2
  //    }
  //    scf.yield %1
  //  }
  // ```
  //
  // The innermost loop is handled using `replaceWithAdditionalYields`,
  // which works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
                                               innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    return llvm::to_vector(llvm::map_range(
        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
        [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
  return newLoopNest;
}

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types are the types of the yielded operands of
/// the single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
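///
/// As an illustrative sketch (names are hypothetical), outlining the body of
///
/// ```mlir
///   scf.execute_region {
///     "some.op"(%a, %b) : (f32, f32) -> ()
///     scf.yield
///   }
/// ```
///
/// as @foo moves `"some.op"` into `func.func @foo(%arg0: f32, %arg1: f32)`,
/// passing the captured values %a and %b as trailing arguments, and leaves
/// behind a `func.call @foo(%a, %b)` followed by the reconstructed
/// terminator.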
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(region's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        IRMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}

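// Hedged sketch of the intended effect: for
//
// ```mlir
//   scf.if %cond {
//     "some.op"(%x) : (f32) -> ()
//   }
// ```
//
// the then-block body is outlined into a new function named `thenFnName`
// (taking the captured %x as an argument) and replaced in place by a matching
// `func.call`; the `scf.if` itself remains.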
LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
                                func::FuncOp *thenFn, StringRef thenFnName,
                                func::FuncOp *elseFn, StringRef elseFnName) {
  IRRewriter rewriter(b);
  Location loc = ifOp.getLoc();
  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
  if (thenFn && !ifOp.getThenRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getThenRegion(), thenFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *thenFn = *outlinedFuncOpOrFailure;
  }
  if (elseFn && !ifOp.getElseRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getElseRegion(), elseFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *elseFn = *outlinedFuncOpOrFailure;
  }
  return success();
}

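// Illustrative sketch: for a nest such as
//
// ```mlir
//   scf.parallel ... {          // %outer
//     scf.parallel ... { ... }  // %inner
//   }
// ```
//
// only %inner is appended to `result`; %outer merely causes the function to
// report that it encloses parallel loops.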
bool mlir::getInnermostParallelLoops(Operation *rootOp,
                                     SmallVectorImpl<scf::ParallelOp> &result) {
  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
  bool rootEnclosesPloops = false;
  for (Region &region : rootOp->getRegions()) {
    for (Block &block : region.getBlocks()) {
      for (Operation &op : block) {
        bool enclosesPloops = getInnermostParallelLoops(&op, result);
        rootEnclosesPloops |= enclosesPloops;
        if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
          rootEnclosesPloops = true;

          // Collect parallel loop if it is an innermost one.
          if (!enclosesPloops)
            result.push_back(ploop);
        }
      }
    }
  }
  return rootEnclosesPloops;
}

// Build the IR that performs ceil division of a positive value by a constant:
//    ceildiv(a, B) = divis(a + (B-1), B)
// where divis is rounding-to-zero division.
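// For example, ceildiv(10, 4) = divis(10 + 3, 4) = divis(13, 4) = 3.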
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             int64_t divisor) {
  assert(divisor > 0 && "expected positive divisor");
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");

  Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
  Value divisorCst = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(dividend.getType(), divisor));
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
  return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}

// Build the IR that performs ceil division of a positive value by another
// positive value:
//    ceildiv(a, b) = divis(a + (b - 1), b)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             Value divisor) {
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");
  Value cstOne = builder.create<arith::ConstantOp>(
      loc, builder.getOneAttr(dividend.getType()));
  Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
  return builder.create<arith::DivUIOp>(loc, sum, divisor);
}

/// Returns the trip count of `forOp` if its lower bound, upper bound, and
/// step are constants, and std::nullopt otherwise. The trip count is computed
/// as ceilDiv(upperBound - lowerBound, step).
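/// For example, a loop from 0 to 10 with step 3 has a trip count of
/// ceilDiv(10 - 0, 3) = 4 (iterations 0, 3, 6, 9).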
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
  if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
    return {};

  // Constant loop bounds computation.
  int64_t lbCst = lbCstOp.value();
  int64_t ubCst = ubCstOp.value();
  int64_t stepCst = stepCstOp.value();
  assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
         "expected positive loop bounds and step");
  return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
}

/// Generates unrolled copies of the body of `loopBodyBlock`, whose induction
/// variable is `forOpIV`, by `unrollFactor`, calling `ivRemapFn` to remap
/// `forOpIV` for each unrolled body. If specified, annotates the ops in each
/// unrolled iteration using `annotateFn`.
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Builder to insert unrolled bodies just before the terminator of the body
  // of 'forOp'.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
  if (!annotateFn)
    annotateFn = defaultAnnotateFn;

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional
  // copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    IRMapping operandMap;

    // Prepare operand map.
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
      Operation *clonedOp = builder.clone(*it, operandMap);
      annotateFn(i, clonedOp, builder);
    }

    // Update yielded values.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
  }

  // Make sure we annotate the Ops in the original body. We do this last so
  // that any annotations are not copied into the cloned Ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}

/// Unrolls 'forOp' by 'unrollFactor' and returns the unrolled main loop and
/// the epilogue loop, if the loop is unrolled.
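///
/// Hedged sketch of the transformation for a factor of 2 (constants shown
/// symbolically):
///
/// ```mlir
///   scf.for %i = %c0 to %c10 step %c1 {
///     "use"(%i) : (index) -> ()
///   }
/// ```
///
/// becomes, roughly,
///
/// ```mlir
///   scf.for %i = %c0 to %c10 step %c2 {
///     "use"(%i) : (index) -> ()
///     %i1 = arith.addi %i, %c1 : index
///     "use"(%i1) : (index) -> ()
///   }
/// ```
///
/// with a separate epilogue loop emitted when the trip count is not a
/// multiple of the factor (not needed here since 10 is even).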
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{forOp, std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
    if (unrollFactor == 1) {
      if (*constTripCount == 1 &&
          failed(forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{forOp, std::nullopt};
    }

    int64_t tripCountEvenMultiple =
        *constTripCount - (*constTripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
          loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
                                            upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantOp>(
                             loc, boundsBuilder.getIntegerAttr(
                                      step.getType(), stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
        loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getInitArgs().size(), results);
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step,
            b.create<arith::ConstantOp>(loc,
                                        b.getIntegerAttr(iv.getType(), i)));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}

/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
  IRRewriter rewriter(forOp.getContext());
  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
  if (!mayBeConstantTripCount.has_value())
    return failure();
  uint64_t tripCount = *mayBeConstantTripCount;
  if (tripCount == 0)
    return success();
  if (tripCount == 1)
    return forOp.promoteIfSingleIteration(rewriter);
  return loopUnrollByFactor(forOp, tripCount);
}

/// Check if bounds of all inner loops are defined outside of `forOp`
/// and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}

/// Unrolls and jams this loop by the specified factor.
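///
/// Hedged sketch for a factor of 2: in
///
/// ```mlir
///   scf.for %i = %c0 to %c8 step %c1 {
///     "producer"(%i) : (index) -> ()
///     scf.for %j = %c0 to %c4 step %c1 {
///       "consumer"(%i, %j) : (index, index) -> ()
///     }
///   }
/// ```
///
/// the step of %i is doubled, "producer" is duplicated for %i and %i + 1, and
/// the two copies of the inner body are jammed into a single %j loop.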
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG("failed to unroll and jam: inner bounds are not invariant");
    return failure();
  }

  // Currently, `scf.for` operations with results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG("failed to unroll and jam: unsupported loop with results");
    return failure();
  }

  // Currently, only a constant trip count that is divisible by the unroll-jam
  // factor is supported.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG("failed to unroll and jam: trip count could not be determined");
    return failure();
  }
  if (unrollJamFactor > *tripCount) {
    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
         "count");
    unrollJamFactor = *tripCount;
  } else if (*tripCount % unrollJamFactor != 0) {
    LDBG("failed to unroll and jam: unsupported trip count that is not a "
         "multiple of unroll jam factor");
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer<scf::ForOp> jbg;
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(newIterArgs[j],
                               newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(newResults[j],
                               newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of the loop being unroll-jammed by the unroll-jam factor.
  rewriter.setInsertionPoint(forOp);
  int64_t step = forOp.getConstantStep()->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), forOp.getStep(),
      rewriter.createOrFold<arith::ConstantOp>(
          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            forOp.getLoc(), builder.getIndexAttr(step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
        operandMaps[i - 1].map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such mapped value exists.
        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
                            operandMaps[i - 1].lookupOrDefault(
                                newForOp.getOperand(numControlOperands + j)));
        yieldOp.setOperand(
            i * oldNumYieldOperands + j,
            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}

Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
                                           OpFoldResult lb, OpFoldResult ub,
                                           OpFoldResult step) {
  Range normalizedLoopBounds;
  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  AffineExpr e = (s1 - s0).ceilDiv(s2);
  normalizedLoopBounds.size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
  return normalizedLoopBounds;
}

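/// For example (hedged sketch), bounds (lb = 4, ub = 20, step = 2) normalize
/// to offset 0, size ceilDiv(20 - 4, 2) = 8, and stride 1.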
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                     OpFoldResult lb, OpFoldResult ub,
                                     OpFoldResult step) {
  if (getType(lb).isIndex()) {
    return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
  }
  // For non-index types, generate `arith` instructions.
  // Check if the loop is already known to have a constant zero lower bound or
  // a constant one step.
  bool isZeroBased = false;
  if (auto lbCst = getConstantIntValue(lb))
    isZeroBased = lbCst.value() == 0;

  bool isStepOne = false;
  if (auto stepCst = getConstantIntValue(step))
    isStepOne = stepCst.value() == 1;

  Type rangeType = getType(lb);
  assert(rangeType == getType(ub) && rangeType == getType(step) &&
         "expected matching types");

  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
  // assuming the step is strictly positive. Update the bounds and the step
  // of the loop to go from 0 to the number of iterations, if necessary.
  if (isZeroBased && isStepOne)
    return {lb, ub, step};

  OpFoldResult diff = ub;
  if (!isZeroBased) {
    diff = rewriter.createOrFold<arith::SubIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
        getValueOrCreateConstantIntOp(rewriter, loc, lb));
  }
  OpFoldResult newUpperBound = diff;
  if (!isStepOne) {
    newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
        getValueOrCreateConstantIntOp(rewriter, loc, step));
  }

  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
  OpFoldResult newStep = rewriter.getOneAttr(rangeType);

  return {newLowerBound, newUpperBound, newStep};
}

static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  AffineExpr d0, s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr e = d0 * s1 + s0;
  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
      rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedIvVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
  SmallPtrSet<Operation *, 1> preservedUses;
  // If an `affine.apply` operation is generated for denormalization, the use
  // of `origLb` in those ops must not be replaced. These are not generated
  // when `origLb == 0` and `origStep == 1`.
  if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
      preservedUses.insert(preservedUse);
    }
  }
  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
}

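// Illustrative sketch: for a loop that was normalized from (lb = 4, step = 2),
// uses of the normalized induction variable %iv are rewritten to %iv * 2 + 4,
// emitted either as `arith.muli`/`arith.addi` or, for index types, as a
// single folded `affine.apply`.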
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
                                        Value normalizedIv, OpFoldResult origLb,
                                        OpFoldResult origStep) {
  if (getType(origLb).isIndex()) {
    return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
                                                    origLb, origStep);
  }
  Value denormalizedIv;
  SmallPtrSet<Operation *, 2> preserve;
  bool isStepOne = isOneInteger(origStep);
  bool isZeroBased = isZeroInteger(origLb);

  Value scaled = normalizedIv;
  if (!isStepOne) {
    Value origStepValue = getValueOrCreateConstantIntOp(rewriter, loc, origStep);
    scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
    preserve.insert(scaled.getDefiningOp());
  }
  denormalizedIv = scaled;
  if (!isZeroBased) {
    Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
    denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
    preserve.insert(denormalizedIv.getDefiningOp());
  }

  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}

static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<OpFoldResult> values) {
  assert(!values.empty() && "unexpected empty array");
  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  AffineExpr mul = s0 * s1;
  OpFoldResult products = rewriter.getIndexAttr(1);
  for (auto v : values) {
    products = affine::makeComposedFoldedAffineApply(
        rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
  }
  return products;
}

/// Helper function to multiply a sequence of values.
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
                                       ArrayRef<Value> values) {
  assert(!values.empty() && "unexpected empty list");
  if (getType(values.front()).isIndex()) {
    SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
    OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
    return getValueOrCreateConstantIndexOp(rewriter, loc, product);
  }
  std::optional<Value> productOf;
  for (auto v : values) {
    auto vOne = getConstantIntValue(v);
    if (vOne && vOne.value() == 1)
      continue;
    if (productOf)
      productOf =
          rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
    else
      productOf = v;
  }
  if (!productOf) {
    productOf = rewriter
                    .create<arith::ConstantOp>(
                        loc, rewriter.getOneAttr(getType(values.front())))
                    .getResult();
  }
  return productOf.value();
}

/// For each original loop, the value of the
/// induction variable can be obtained by dividing the induction variable of
/// the linearized loop by the total number of iterations of the loops nested
/// in it modulo the number of iterations in this loop (remove the values
/// related to the outer loops):
///   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
/// Compute these iteratively from the innermost loop by creating a "running
/// quotient" of division by the range.
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                             Value linearizedIv, ArrayRef<Value> ubs) {

  if (linearizedIv.getType().isIndex()) {
    Operation *delinearizedOp =
        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
                                                          ubs);
    auto resultVals = llvm::map_to_vector(
        delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
  }

  SmallVector<Value> delinearizedIvs(ubs.size());
  SmallPtrSet<Operation *, 2> preservedUsers;

  llvm::BitVector isUbOne(ubs.size());
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    auto ubCst = getConstantIntValue(ub);
    if (ubCst && ubCst.value() == 1)
      isUbOne.set(index);
  }

  // Prune the leading ubs that are all ones.
  unsigned numLeadingOneUbs = 0;
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    if (!isUbOne.test(index)) {
      break;
    }
    delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(ub.getType()));
    numLeadingOneUbs++;
  }

  Value previous = linearizedIv;
  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
    if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
      previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
      preservedUsers.insert(previous.getDefiningOp());
    }
    Value iv = previous;
    if (i != e - 1) {
      if (!isUbOne.test(idx)) {
        iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
        preservedUsers.insert(iv.getDefiningOp());
      } else {
        iv = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(ubs[idx].getType()));
      }
    }
    delinearizedIvs[idx] = iv;
  }
  return {delinearizedIvs, preservedUsers};
}

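// Hedged sketch of coalescing a 2-deep nest:
//
// ```mlir
//   scf.for %i = %c0 to %c4 step %c1 {
//     scf.for %j = %c0 to %c8 step %c1 {
//       "use"(%i, %j) : (index, index) -> ()
//     }
//   }
// ```
//
// becomes a single loop of 4 * 8 = 32 iterations whose induction variable is
// delinearized back into (%i, %j) at the top of the body.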
LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
                                  MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return failure();

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
  // allows the following code to assume upperBound is the number of
  // iterations.
  for (auto loop : loops) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(outermost);
    Value lb = loop.getLowerBound();
    Value ub = loop.getUpperBound();
    Value step = loop.getStep();
    auto newLoopRange =
        emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

    rewriter.modifyOpInPlace(loop, [&]() {
      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.offset));
      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.size));
      loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                 newLoopRange.stride));
    });
    rewriter.setInsertionPointToStart(innermost.getBody());
    denormalizeInductionVariable(rewriter, loop.getLoc(),
                                 loop.getInductionVar(), lb, step);
  }

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outermost);
  Location loc = outermost.getLoc();
  SmallVector<Value> upperBounds = llvm::map_to_vector(
      loops, [](auto loop) { return loop.getUpperBound(); });
  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
  outermost.setUpperBound(upperBound);

  rewriter.setInsertionPointToStart(innermost.getBody());
  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
      rewriter, loc, outermost.getInductionVar(), upperBounds);
  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
                                preservedUsers);

  for (int i = loops.size() - 1; i > 0; --i) {
    auto outerLoop = loops[i - 1];
    auto innerLoop = loops[i];

    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
    auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
    for (Value &yieldedVal : yieldedVals) {
      // The yielded value may be an iteration argument of the inner loop
      // which is about to be inlined.
      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
      if (iter != innerLoop.getRegionIterArgs().end()) {
        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
        // The `outerLoop` iter args are identical to the `innerLoop` init
        // args.
        assert(iterArgIndex < innerLoop.getInitArgs().size());
        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
      }
    }
    rewriter.eraseOp(innerTerminator);

    SmallVector<Value> innerBlockArgs;
    innerBlockArgs.push_back(delinearizeIvs[i]);
    llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
    rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
                               Block::iterator(innerLoop), innerBlockArgs);
    rewriter.replaceOp(innerLoop, yieldedVals);
  }
  return success();
}

994 | if (loops.empty()) { |
995 | return failure(); |
996 | } |
997 | IRRewriter rewriter(loops.front().getContext()); |
998 | return coalesceLoops(rewriter, loops); |
999 | } |
1000 | |
LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
  LogicalResult result(failure());
  SmallVector<scf::ForOp> loops;
  getPerfectlyNestedLoops(loops, op);

  // Look for a band of loops that can be coalesced, i.e. perfectly nested
  // loops with bounds defined above some loop.

  // 1. For each loop, find above which parent loop its bounds operands are
  // defined.
  SmallVector<unsigned> operandsDefinedAbove(loops.size());
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    operandsDefinedAbove[i] = i;
    for (unsigned j = 0; j < i; ++j) {
      SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
                                           loops[i].getUpperBound(),
                                           loops[i].getStep()};
      if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
        operandsDefinedAbove[i] = j;
        break;
      }
    }
  }

  // 2. For each inner loop, check that the iter_args of the immediately outer
  // loop are the inits of the immediately inner loop, and that the results of
  // the inner loop are exactly what the immediately outer loop yields. Keep
  // track of where the chain starts from for each loop.
  SmallVector<unsigned> iterArgChainStart(loops.size());
  iterArgChainStart[0] = 0;
  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
    // By default set the start of the chain to itself.
    iterArgChainStart[i] = i;
    auto outerloop = loops[i - 1];
    auto innerLoop = loops[i];
    if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
      continue;
    }
    if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
      continue;
    }
    auto outerloopTerminator = outerloop.getBody()->getTerminator();
    if (!llvm::equal(outerloopTerminator->getOperands(),
                     innerLoop.getResults())) {
      continue;
    }
    iterArgChainStart[i] = iterArgChainStart[i - 1];
  }

  // 3. Identify bands of loops such that the operands of all of them are
  // defined above the first loop in the band. Traverse the nest bottom-up
  // so that modifications don't invalidate the inner loops.
  for (unsigned end = loops.size(); end > 0; --end) {
    unsigned start = 0;
    for (; start < end - 1; ++start) {
      auto maxPos =
          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
                            std::next(operandsDefinedAbove.begin(), end));
      if (maxPos > start)
        continue;
      if (iterArgChainStart[end - 1] > start)
        continue;
      auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
      if (succeeded(coalesceLoops(band)))
        result = success();
      break;
    }
    // If a band was found and transformed, keep looking at the loops above
    // the outermost transformed loop.
    if (start != end - 1)
      end = start + 1;
  }
  return result;
}

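// Hedged sketch: with `combinedDimensions = {{0, 1}, {2}}`, a 3-D
// `scf.parallel` over (%i, %j, %k) becomes a 2-D `scf.parallel` whose first
// induction variable enumerates the combined (%i, %j) space; %i and %j are
// recovered in the body with `arith.divsi`/`arith.remsi`.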
void mlir::collapseParallelLoops(
    RewriterBase &rewriter, scf::ParallelOp loops,
    ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    llvm::sort(dims);

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder::InsertionGuard g2(rewriter);
    rewriter.setInsertionPoint(loops);
    Value lb = loops.getLowerBound()[i];
    Value ub = loops.getUpperBound()[i];
    Value step = loops.getStep()[i];
    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
        rewriter, loops.getLoc(), newLoopRange.size));

    rewriter.setInsertionPointToStart(loops.getBody());
    denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
                                 step);
  }

  // Combine iteration spaces.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  for (auto &sortedDimension : sortedDimensions) {
    Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimension) {
      newUpperBound = rewriter.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized
  // value that is un-normalized already by the previous logic.
  auto newPloop = rewriter.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration.
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}

// Hoist the ops within `outer` that appear before `inner`.
// Such ops include the ops that have been introduced by parametric tiling.
// Ops that come from triangular loops (i.e. that belong to the program slice
// rooted at `outer`) and ops that have side effects cannot be hoisted.
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
  SetVector<Operation *> forwardSlice;
  ForwardSliceOptions options;
  options.filter = [&inner](Operation *op) {
    return op != inner.getOperation();
  };
  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
  LogicalResult status = success();
  SmallVector<Operation *, 8> toHoist;
  for (auto &op : outer.getBody()->without_terminator()) {
    // Stop when encountering the inner loop.
    if (&op == inner.getOperation())
      break;
    // Skip over non-hoistable ops.
    if (forwardSlice.count(&op) > 0) {
      status = failure();
      continue;
    }
    // Skip intermediate scf::ForOp, these are not considered a failure.
    if (isa<scf::ForOp>(op))
      continue;
    // Skip other ops with regions.
    if (op.getNumRegions() > 0) {
      status = failure();
      continue;
    }
    // Skip if op has side effects.
    // TODO: loads to immutable memory regions are ok.
    if (!isMemoryEffectFree(&op)) {
      status = failure();
      continue;
    }
    toHoist.push_back(&op);
  }
  auto *outerForOp = outer.getOperation();
  for (auto *op : toHoist)
    op->moveBefore(outerForOp);
  return status;
}

// Traverse the interTile and intraTile loops and try to hoist ops such that
// bands of perfectly nested loops are isolated.
// Return failure if either perfect interTile or perfect intraTile bands
// cannot be formed.
static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
  LogicalResult status = success();
  const Loops &interTile = tileLoops.first;
  const Loops &intraTile = tileLoops.second;
  auto size = interTile.size();
  assert(size == intraTile.size());
  if (size <= 1)
    return success();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
                               : failure();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
                               : failure();
  return status;
}

/// Collect perfectly nested loops starting from `rootForOp`. Loops are
/// perfectly nested if each loop is the first and only non-terminator
/// operation in the parent loop. Collect at most `maxLoops` loops and append
/// them to `forOps`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  for (unsigned i = 0; i < maxLoops; ++i) {
    forOps.push_back(rootForOp);
    Block &body = rootForOp.getRegion().front();
    if (body.begin() != std::prev(body.end(), 2))
      return;

    rootForOp = dyn_cast<T>(&body.front());
    if (!rootForOp)
      return;
  }
}

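// Hedged sketch: stripmining `scf.for %i = %lb to %ub step %s` by a factor %f
// scales the outer step to %s * %f and materializes, in each target body, an
// inner loop of the form
//
// ```mlir
//   scf.for %ii = %i to min(%ub, %i + %s * %f) step %s { ... }
// ```
//
// where the bound is emitted with `arith.addi` and `arith.minsi`.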
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done.
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value ub =
        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}

// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
// Returns the new for operation, nested immediately under `target`.
template <typename SizeType>
static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
                                scf::ForOp target) {
  // TODO: Use cheap structural assertions that targets are nested under
  // forOp and that targets are not nested under each other when DominanceInfo
  // exposes the capability. It seems overkill to construct a whole function
  // dominance tree at this point.
  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
  assert(res.size() == 1 && "Expected 1 inner forOp");
  return res[0];
}

SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
                                 ArrayRef<Value> sizes,
                                 ArrayRef<scf::ForOp> targets) {
  SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
  SmallVector<scf::ForOp, 8> currentTargets(targets);
  for (auto it : llvm::zip(forOps, sizes)) {
    auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
    res.push_back(step);
    currentTargets = step;
  }
  return res;
}

Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                 scf::ForOp target) {
  SmallVector<scf::ForOp, 8> res;
  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
    res.push_back(llvm::getSingleElement(loops));
  return res;
}

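// For example (schematically; `min` abbreviates the generated arith.minsi),
// tiling the following nest with sizes [%ts1, %ts2]
//
// ```mlir
// scf.for %i = %lb1 to %ub1 step %s1 {
//   scf.for %j = %lb2 to %ub2 step %s2 {
//     "op"(%i, %j) : (index, index) -> ()
//   }
// }
// ```
//
// produces a 4-d nest whose two innermost loops traverse a single tile:
//
// ```mlir
// scf.for %i = %lb1 to %ub1 step %s1_x_ts1 {    // %s1 * %ts1
//   scf.for %j = %lb2 to %ub2 step %s2_x_ts2 {  // %s2 * %ts2
//     scf.for %ii = %i to min(%ub1, %i + %s1_x_ts1) step %s1 {
//       scf.for %jj = %j to min(%ub2, %j + %s2_x_ts2) step %s2 {
//         "op"(%ii, %jj) : (index, index) -> ()
//       }
//     }
//   }
// }
// ```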
Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
  // Collect perfectly nested loops. If more size values are provided than
  // there are nested loops, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  return ::tile(forOps, sizes, forOps.back());
}

void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}

TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops. If more size values are provided than
  // there are nested loops, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  // Compute the tile sizes such that the i-th outer loop executes size[i]
  // iterations. Given that the loop currently executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
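  // For instance, with hypothetical values lowerBound = 0, upperBound = 130,
  // step = 1 and sizes[i] = 4: numIterations = 130, the tile size is
  // ceildiv(130, 4) = 33, and the i-th outer loop then executes
  // ceildiv(130, 33) = 4 iterations, as requested.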
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
                                               forOp.getLowerBound());
    Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, numIterations, sizes[i]);
    tileSizes.push_back(iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, tileSizes, forOps.back());
  TileLoops tileLoops = std::make_pair(forOps, intraTile);

  // TODO: for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  (void)tryIsolateBands(tileLoops);

  return tileLoops;
}

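// For example (a sketch), fusing the two independent loops
//
// ```mlir
// %0 = scf.forall (%i) in (%n) shared_outs(%o0 = %t0) -> (tensor<?xf32>) {...}
// %1 = scf.forall (%i) in (%n) shared_outs(%o1 = %t1) -> (tensor<?xf32>) {...}
// ```
//
// yields a single loop carrying both shared_outs:
//
// ```mlir
// %fused:2 = scf.forall (%i) in (%n)
//     shared_outs(%o0 = %t0, %o1 = %t1)
//     -> (tensor<?xf32>, tensor<?xf32>) {...}
// ```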
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused shared_outs.
  SmallVector<Value> fusedOuts;
  llvm::append_range(fusedOuts, target.getOutputs());
  llvm::append_range(fusedOuts, source.getOutputs());

  // Create a new scf.forall op after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
      source.getMixedStep(), fusedOuts, source.getMapping());

  // Map control operands.
  IRMapping mapping;
  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());

  // Map shared outs.
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Fuse the old terminator in_parallel ops into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  for (Operation &op : targetTerm.getYieldingOps())
    rewriter.clone(op, mapping);
  for (Operation &op : sourceTerm.getYieldingOps())
    rewriter.clone(op, mapping);

  // Replace the old loops by substituting their uses with results of the
  // fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}

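// For example (a sketch), fusing the two independent loops
//
// ```mlir
// %0 = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) {...}
// %1 = scf.for %i = %lb to %ub step %s iter_args(%b = %y) -> (f32) {...}
// ```
//
// yields a single loop carrying both iter_args:
//
// ```mlir
// %fused:2 = scf.for %i = %lb to %ub step %s
//     iter_args(%a = %x, %b = %y) -> (f32, f32) {...}
// ```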
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                scf::ForOp source,
                                                RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused init_args, with target's init_args before source's init_args.
  SmallVector<Value> fusedInitArgs;
  llvm::append_range(fusedInitArgs, target.getInitArgs());
  llvm::append_range(fusedInitArgs, source.getInitArgs());

  // Create a new scf.for op after the source loop. The builder inserts an
  // argument-free scf.yield terminator only when `fusedInitArgs` is empty.
  rewriter.setInsertionPointAfter(source);
  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
      source.getStep(), fusedInitArgs);

  // Map original induction variables and operands to those of the fused loop.
  IRMapping mapping;
  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Merge target's body into the new (fused) for loop, then source's body.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Build fused yield results by appropriately mapping original yield operands.
  SmallVector<Value> yieldResults;
  for (Value operand : target.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  for (Value operand : source.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  if (!yieldResults.empty())
    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);

  // Replace the old loops by substituting their uses with results of the
  // fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}

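// For example (schematically), normalizing
//
// ```mlir
// scf.forall (%i) = (%lb) to (%ub) step (%s) { "use"(%i) }
// ```
//
// yields a loop over the range [0, ceildiv(%ub - %lb, %s)) with unit step,
// in which the original induction variable is recomputed:
//
// ```mlir
// scf.forall (%i) in (%n) {   // %n = ceildiv(%ub - %lb, %s)
//   %iv = %i * %s + %lb
//   "use"(%iv)
// }
// ```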
FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
                                                 scf::ForallOp forallOp) {
  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();

  if (forallOp.isNormalized())
    return forallOp;

  OpBuilder::InsertionGuard g(rewriter);
  auto loc = forallOp.getLoc();
  rewriter.setInsertionPoint(forallOp);
  SmallVector<OpFoldResult> newUbs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
    Range normalizedLoopParams =
        emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    newUbs.push_back(normalizedLoopParams.size);
  }
  (void)foldDynamicIndexList(newUbs);

  // Use the normalized builder since the lower bounds are always 0 and the
  // steps are always 1.
  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
      loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
                              normalizedForallOp.getBodyRegion(),
                              normalizedForallOp.getBodyRegion().begin());
  // Remove the original empty block in the new loop.
  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());

  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
  // Update the users of the original loop variables.
  for (auto [idx, iv] :
       llvm::enumerate(normalizedForallOp.getInductionVars())) {
    auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
    auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
    denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
  }

  rewriter.replaceOp(forallOp, normalizedForallOp);
  return normalizedForallOp;
}