//===- HoistPadding.cpp - Hoisting for tensor::PadOp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::detail;

#ifndef NDEBUG
static bool debugPrintLoopInShortForm(Operation *op) {
  AsmState state(op->getParentOfType<func::FuncOp>());
  (void)state;
  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    forOp.getInductionVar().printAsOperand(dbgs(), state);
    dbgs() << " @ " << forOp.getOperation();
    return true;
  }
  return false;
}
#endif

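/// Debug aid: print the ops in `backwardSlice`, one per line, using the short
/// induction-variable form for scf.for loops when possible.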
static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
  LLVM_DEBUG(llvm::interleaveComma(backwardSlice, DBGS() << "--backwardSlice:",
                                   [](Operation *op) {
                                     dbgs() << "\n";
                                     DBGS() << "----";
                                     if (debugPrintLoopInShortForm(op)) {
                                       dbgs() << "\n";
                                       return;
                                     }
                                     dbgs() << *op << "\n";
                                   });
             DBGS() << "\n";);
}

/// Return at most nLevels of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

/// Return the immediately enclosing scf::ForOp loops, from innermost to
/// outermost, stopping either after `untilLoop` has been collected or at the
/// first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
static void
getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (outermostEnclosingForOp != untilLoop &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

// Get all the ops in the backward slice starting from `padOp` that are
// dominated by the outermost enclosing loop.
// This also requires tracking ops defining values used in the region but
// defined above.
static void computeBackwardSlice(tensor::PadOp padOp,
                                 scf::ForOp outermostEnclosingForOp,
                                 SetVector<Operation *> &backwardSlice) {
  DominanceInfo domInfo(outermostEnclosingForOp);
  BackwardSliceOptions sliceOptions;
  sliceOptions.filter = [&](Operation *op) {
    return domInfo.dominates(outermostEnclosingForOp, op) &&
           !padOp->isProperAncestor(op);
  };
  sliceOptions.inclusive = true;

  // First, add the ops required to compute the region to the backwardSlice.
  SetVector<Value> valuesDefinedAbove;
  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
                            valuesDefinedAbove);
  for (Value v : valuesDefinedAbove) {
    getBackwardSlice(v, &backwardSlice, sliceOptions);
  }
  // Then, add the backward slice from padOp itself.
  getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
}

//===----------------------------------------------------------------------===//
// HoistPaddingAnalysis Implementation.
//===----------------------------------------------------------------------===//

namespace {
/// Analysis class to support tensor::PadOp hoisting across multiple enclosing
/// loops. The failure conditions are:
///   1. Pad op has a use that is not an input of a LinalgOp.
///   2. Pad op does not have a constant padding value.
///   3. There is no immediately enclosing scf::ForOp.
///   4. The backward slice from the pad op to the scf::ForOp to hoist above
///      contains an unknown op with non-index type operands, a region, or a
///      memory effect.
///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
///      empty.
///   6. The source tensor of the pad op is not defined by an extract slice op.
///   7. The source tensor of the extract slice op is not defined outside of
///      the outermost enclosing scf::ForOp.
///   8. There is no enclosing scf::ForOp that indexes the padded data.
/// Other cases succeed and will trigger hoisting of the pad op.
struct HoistPaddingAnalysis {
  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);

  bool isValid() { return valid.has_value() && valid.value(); }
  bool isInvalid() { return valid.has_value() && !valid.value(); }

  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                 Location loc) const;

  /// Performs optional hoisting to enable hoist padding to occur. This may be
  /// necessary when `sliceOp` is not defined outside of the outermost
  /// enclosing loop we want to hoist above.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// // %source is available for packing here!
  /// scf.for %i
  ///   scf.for %j
  ///     scf.for %k
  ///       %slice = tensor.extract_slice %source [%i, %j]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  void enableHoistPadding(RewriterBase &rewriter);

  /// Common analysis builder to finalize the construction of the analysis once
  /// optional `enableHoistPadding` has run.
  /// `reverseEnclosingLoops.back()` is the loop to hoist above.
  void finalizeHoistPaddingAnalysis();

private:
  /// Encodes whether the analysis is valid and hoisting can proceed.
  std::optional<bool> valid;

  /// The padOp to hoist.
  tensor::PadOp opToHoist;

  /// Immediately enclosing loops considered for hoisting padding.
  SmallVector<scf::ForOp> reverseEnclosingLoops;

  /// Drop any non-index dependencies of `padOp` and `sliceOp` from
  /// `backwardSlice`. The method follows the use-def chains of the index
  /// operands consumed by `padOp` and `sliceOp` and drops the operations
  /// not part of this index computation. Afterwards, the filtered
  /// `backwardSlice` contains only the loops whose induction variable is
  /// used, directly or indirectly, to index the padded tensor. The method
  /// returns failure if the filtered backward slice contains an unexpected
  /// operation.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// scf.for %i
  ///   %unrelated = linalg.fill(%cst, %arg1)  // not used to index %source!
  ///   scf.for %j (%arg2 = %unrelated)
  ///     scf.for %k                           // not used to index %source!
  ///       %ubi = affine.min #map(%i)
  ///       %ubj = affine.min #map(%j)
  ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  /// dropNonIndexDependencies(%padded_slice, %slice)
  /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
  LogicalResult dropNonIndexDependencies();

public:
  /// The outermost loop, determined by `nLevels` above which `padOp` will
  /// be hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `padOp` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOp loops immediately enclosing `padOp` such that:
  ///   1. they are nested under `outermostEnclosingForOp` (inclusive), and
  ///   2. their induction variables are used, directly or indirectly, in the
  ///      computation of `padOp`.
  /// The span of these loops determines the footprint of the packed tensor.
  SmallVector<scf::ForOp> packingLoops;

  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
  tensor::ExtractSliceOp sliceOp;

  /// If non-null, this is the unique scf::ForOp that consumes the `sliceOp`.
  scf::ForOp padConsumingForOp;
};

} // namespace

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get at most `numLoops` of immediately enclosing loops.
  getAtMostNEnclosingLoops(opToHoist, numLoops, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  outermostEnclosingForOp = reverseEnclosingLoops.back();
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
                                           scf::ForOp outermostEnclosingForOp)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get enclosing loops until outermostEnclosingForOp.
  getEnclosingLoopsUntil(opToHoist, outermostEnclosingForOp,
                         reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  this->outermostEnclosingForOp = reverseEnclosingLoops.back();
  if (this->outermostEnclosingForOp != outermostEnclosingForOp) {
    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
  if (isInvalid())
    return;
  // If the padded data is not yet available before entering the outermost
  // enclosing loop, try to apply hoisting on this outermost loop.
  // TODO: we may want finer-grained hoisting of only that particular
  // `sliceOp`.
  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    outermostEnclosingForOp = cast<scf::ForOp>(
        hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
  }
}

void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
  if (isInvalid())
    return;

  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
                      << outermostEnclosingForOp << "\n"
                      << "--sliceOp: " << sliceOp << "\n"
                      << "--sliceOp.getSource(): " << sliceOp.getSource()
                      << "\n");
    LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
    valid = false;
    return;
  }
  if (sliceOp->hasOneUse()) {
    padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
  }

  // Check that the region of `padOp` depends on a constant only. Adding
  // hoisting support for arbitrary padding regions would require cloning all
  // dependencies captured by the padding region.
  Value paddingValue = opToHoist.getConstantPaddingValue();
  if (!paddingValue ||
      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
    valid = false;
    return;
  }

  computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
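  // The slice is inclusive of `opToHoist`; if it contains nothing else there
  // is no computation to hoist across the enclosing loops.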
  if (backwardSlice.size() <= 1) {
    valid = false;
    return;
  }

  debugPrintBackwardSlice(backwardSlice);
  // Remove all ops in the backward slice that are not used to index
  // the padded tensor. In particular, keep `padOp`, `sliceOp`, and
  // the loop and affine operations used for the index computation.
  if (failed(dropNonIndexDependencies())) {
    LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
    valid = false;
    return;
  }
  debugPrintBackwardSlice(backwardSlice);

  // Add only the loops part of the filtered `backwardSlice` to the
  // packing loops. All other loops are not used to index the padded
  // data and consequently access the same data in every loop
  // iteration. Adding them to the packing loops would increase the
  // cache footprint of the packed data by storing the same data
  // multiple times.
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
    if (backwardSlice.contains(forOp))
      packingLoops.push_back(forOp);

  // TODO: for multiple loops we need to track the use to the innermost loop.
  if (packingLoops.size() > 1 && padConsumingForOp) {
    LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
                         "Downgrade to 1 loop\n");
    packingLoops.resize(1);
  }

  // Note: at this point, packing loops may be empty but we would still like
  // to hoist the padding if so specified.

  // The analysis is valid and hoisting can occur.
  valid = true;
}

LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
  // Set of all values used for index computation.
  SetVector<Value> indexEdges;

  // Add all index operands of `operation` to `indexEdges`. An index operand
  // is an operand of type index.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // Check if any operation result is contained in `indexEdges`.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  // Starting from `opToHoist` and `sliceOp` walk the use-def edges of index
  // type in `backwardSlice`. Add the index operands of an operation to
  // `indexEdges` and remove all operations from `backwardSlice` that are not
  // part of the index computation.
  //
  // Example:
  // ```
  // %source = linalg.fill(%cst, %arg0)
  // scf.for %i
  //   %unrelated = linalg.fill(%cst, %arg1)  // not used to index %source!
  //   scf.for %j (%arg2 = %unrelated)
  //     scf.for %k                           // not used to index %source!
  //       %ubi = affine.min #map(%i)
  //       %ubj = affine.min #map(%j)
  //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  //       %padded_slice = tensor.pad %slice
  // ```
  // After iterating `backwardSlice` we obtain:
  // indexEdges = [%i, %j, %ubi, %ubj]
  // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add the index operands of `opToHoist` and `sliceOp` to start the
    // exploration of the index computation.
    if (op == opToHoist || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Add the index operands of the loop if its induction variable is
    // used for index computation.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Add the index operands of all other operations if at least one result
    // is used for index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      // Check the operands of the remaining operations all have index type.
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> Skip\n");
        return failure();
      }
      // Check the remaining operations do not have regions or memory effects.
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> Skip\n");
        return failure();
      }
      continue;
    }
    // Remove all other operations not used by the index computation. An
    // exception are constant operations that may be used by `opToHoist`.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}

SmallVector<Value>
HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                  Location loc) const {
  SmallVector<Value> dynamicTensorSizes;

  // Upper bound the packing loop lengths to size the packed tensor. Taking
  // upper bounds can make the sizes of the packed tensor independent of the
  // enclosing loops. This independence is a prerequisite for reusing the same
  // buffer for all enclosing loop iterations and hoisting its allocation out
  // of the enclosing loops.
  for (auto forOp : packingLoops) {
    // Compute an upper bound `ubVal` for the upper bound of `forOp`.
    FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
        rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
        /*stopCondition=*/
        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
          if (v == forOp.getUpperBound())
            return false;
          // Compute a bound that is independent of any affine op results.
          Operation *op = v.getDefiningOp();
          if (!op)
            return true;
          return !isa<affine::AffineMinOp, affine::AffineMaxOp,
                      affine::AffineApplyOp>(op);
        },
        /*closedUB=*/true);
    assert(succeeded(loopUb) && "could not get upper bound");
    Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *loopUb);

    // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
    // store the result to `dynamicTensorSizes`.
    // TODO: instead of using the lower bound of `forOp` directly, implement a
    // lower bound computation similar to the upper bound computation.
    AffineExpr lb, ub, step;
    bindDims(rewriter.getContext(), lb, ub);
    bindSymbols(rewriter.getContext(), step);
    Value res = rewriter.createOrFold<affine::AffineApplyOp>(
        loc, (ub - lb).ceilDiv(step),
        ValueRange{forOp.getLowerBound(), ubVal,
                   cast<scf::ForOp>(forOp).getStep()});
    dynamicTensorSizes.push_back(res);
  }

  return dynamicTensorSizes;
}

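/// Return true if `v` is defined outside of `outer` or is a constant, i.e. it
/// can be used above `outer` without cloning any additional computation.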
static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
}

//===----------------------------------------------------------------------===//
// buildPackingLoopNest Implementation.
//===----------------------------------------------------------------------===//

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.getStep()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(),
        stepVal = forOp.getStep();
  auto loc = forOp->getLoc();
  return rewriter.createOrFold<affine::AffineApplyOp>(
      loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
}

// Build a packing loop nest by iteratively traversing the backward slice and
// clone the operations, iteratively stepping into the loops that we encounter.
// The implementation proceeds in a stack-like fashion:
//   1. Iteratively clone and step into the loops, pushing the
//      `hoistedPackedTensor` deeper in the stack.
//   2. At the innermost loop level, create a GenericOp if `transposeVector` is
//      non-empty.
//   3. At the innermost loop level, create an InsertSliceOp.
//   4. Iteratively pop and yield the result of the InsertSliceOp across the
//      cloned loops.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Step 0. Populate bvm with opToHoist.getSource if relevant.
  BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
  while (bbArg) {
    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
      break;
    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
  }

  // Step 1. Iteratively clone loops and push `hoistedPackedTensor`.
  Value hoistedPackedTensor = emptyOp.getResult();
  OpBuilder::InsertionGuard g(rewriter);
  for (Operation *op : analysis.backwardSlice) {
    // Specifically sit out in the extract_slice(hoistedPackedTensor) case:
    // this is the piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
        LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
        continue;
      }
    }

    // Clone all operations except loops which require special handling.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      // We are at the right insertion point within the loop nest.
      rewriter.clone(*op, bvm);
      continue;
    }

    // Create a packing loop that takes `hoistedPackedTensor` as iteration
    // argument.
    auto clonedForOp = rewriter.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);

    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Do not insert guard here, we get deeper into the loop nest.
    rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(rewriter, outerLoop, clonedForOp);

    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Step 2. Construct offsets, sizes and strides for the innermost level of
  // the packing loop.
  int64_t nPackedLoops = clonedLoopIvs.size();
  // offsets = [clonedLoopIvs, 0 .. 0].
  offsets =
      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
                                leadingHoistedPackedTensorIndexings.end()};
  offsets.append(paddedRank, rewriter.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape].
  sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
  for (int64_t sz : transposedTensorType.getShape()) {
    // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
    if (ShapedType::isDynamic(sz))
      return failure();
    sizes.push_back(rewriter.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank,
                                      rewriter.getIndexAttr(1));

  // Step 3. Optionally transpose the padded tensor.
  GenericOp maybeTransposeOp;
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
        strides);
    maybeTransposeOp = makeTransposeOp(rewriter, loc, paddedTensor,
                                       outputTensor, transposeVector);
    paddedTensor = maybeTransposeOp.getResult(0);
  }

  // Innermost tensor.insert_slice and yields are optional / need loops.
  if (nPackedLoops > 0) {
    // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
    // optionally transposed padded slice into the packed tensor.
    Value inserted = rewriter.create<tensor::InsertSliceOp>(
        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);

    // Step 5. Iteratively pop the stack and propagate the yield.
    Value valueToYield = inserted;
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
      auto forOp = scf::getForInductionVarOwner(iv);
      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
      rewriter.create<scf::YieldOp>(loc, valueToYield);
      valueToYield = forOp.getResult(0);
    }
  }

  return PackingResult{
      offsets,
      sizes,
      strides,
      clonedLoopIvs,
      leadingHoistedPackedTensorIndexings,
      maybeTransposeOp,
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
  // Update actual number of loops, which may be smaller.
  int nPackedLoops = analysis.packingLoops.size();
  LLVM_DEBUG(DBGS() << "\n";
             DBGS() << "Func:\n"
                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();

  // Compute the type of the transposed padded tensor.
  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(paddedTensorType, transposeVector);
  if (failed(transposedTensorType)) {
    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
    return failure();
  }

  // Create the packed tensor<?x?x..? x transposedShape>.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
  // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
  llvm::append_range(packedShape, transposedTensorType->getShape());
  auto hoistedPackedTensorType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());

  // Set the insertion point right before the outer loop and start packing.
  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outerLoop);
  SmallVector<Value> dynamicTensorSizes =
      analysis.getHoistedPackedTensorSizes(rewriter, loc);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, hoistedPackedTensorType.getShape(),
      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);

  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  *transposedTensorType, emptyOp, analysis);
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
    RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }
  IRMapping bvm;
  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  analysis);
}

//===----------------------------------------------------------------------===//
// hoistPaddingOnTensors Implementation.
//===----------------------------------------------------------------------===//

/// Return true if we can walk back the use-def chain from `extractSliceOp` to
/// `expectedSource` going through DestinationStyleOpInterface inits only.
/// This is a poor man's analysis that is sufficient to check that the
/// extractSliceOp matches the tensor.pad we want to hoist.
/// In the future, it will be easier to ensure this with a matching symmetric
/// tensor.unpad op.
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
                    << "\n");
  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
  Value source = extractSliceOp.getSource();
  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
  while (source && source != expectedSource) {
    auto destOp =
        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
    if (!destOp)
      break;
    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
    source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
                 ->get();
  }
  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
  return source == expectedSource;
}

/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
/// iter arg), propagate the `hoistedPackedTensor` value through the same iter
/// arg.
/// TODO: for multiple loops we need to track the use to the innermost loop.
///
/// Match:
/// ```
///    %outerSliceOp = tensor.extract_slice ..
///    %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
///      %hoistedPackedTensor = tensor.pad %arg0
///      %1 = compute %hoistedPackedTensor
///      %2 = tensor.extract_slice %1
///      scf.yield %2
///    }
/// ```
///
/// and rewrite as:
/// ```
///    %outerSliceOp = tensor.extract_slice ..
///    %hoistedPackedTensor = tensor.pad %outerSliceOp
///    %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
///      %1 = compute %arg0
///      scf.yield %1
///    }
///    %2 = tensor.extract_slice %forOp
/// ```
///
/// Return null when no rewrite happened.
static tensor::ExtractSliceOp
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
                    << paddedValueBeforeHoisting << "\n");
  OpOperand *pUse = nullptr;
  for (OpOperand &use : outerSliceOp->getUses()) {
    if (use.getOwner() == forOp) {
      assert(!pUse && "Multiple slice uses in the for loop");
      pUse = &use;
    }
  }
  assert(pUse && "No slice use in the for loop");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());

  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
  if (!yieldingExtractSliceOp)
    return tensor::ExtractSliceOp();

  // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad.
  // In the future, it will be easier to ensure this with a matching symmetric
  // tensor.unpad op.
  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
                                 paddedValueBeforeHoisting))
    return tensor::ExtractSliceOp();

  SmallVector<Value> initArgs = forOp.getInitArgs();
  initArgs[iterArgNumber] = hoistedPackedTensor;
  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();

  int64_t numOriginalForOpResults = initArgs.size();
  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
                    << "\n");
  tensor::ExtractSliceOp extracted;
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(forOp);
    extracted = rewriter.create<tensor::ExtractSliceOp>(
        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
        outerSliceOp.getMixedStrides());
    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
  }
  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
        return yieldOperands;
      }));

  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                    << "\n");
  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
  LLVM_DEBUG(DBGS() << "with result #"
                    << numOriginalForOpResults + iterArgNumber
                    << " of forOp, giving us: " << extracted << "\n");
  rewriter.startOpModification(extracted);
  extracted.getSourceMutable().assign(
      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
  rewriter.finalizeOpModification(extracted);

  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                    << "\n");
  LLVM_DEBUG(DBGS() << "with region iter arg #"
                    << numOriginalForOpResults + iterArgNumber << "\n");
  rewriter.replaceAllUsesWith(
      paddedValueBeforeHoisting,
      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));

  return extracted;
}

/// Produce a tensor extracted from the packingResult. This can be used as a
/// replacement for `opToHoist` in callers.
static Value replaceByPackingResult(RewriterBase &rewriter,
                                    const IRMapping &bvm,
                                    tensor::PadOp opToHoist,
                                    RankedTensorType transposedTensorType,
                                    const HoistPaddingAnalysis &analysis,
                                    const PackingResult &packingResult) {
  // The replacement occurs under a single insertion point within the original
  // loop, just before opToHoist.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(opToHoist);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  int64_t nPackedLoops = packingResult.clonedLoopIvs.size();
  LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n");

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;

  Value hoistedPackedTensor;
  SmallVector<Value> loopIterationCounts;
  SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                    rewriter.getIndexAttr(0));
  if (nPackedLoops > 0) {
    loopIterationCounts =
        llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
          return buildLoopIterationCount(rewriter, outerLoop,
                                         cast<scf::ForOp>(loop));
        }));
    // Assert all loop iteration counts can be computed.
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
      llvm_unreachable("loop independence prerequisite not met");

    // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
    hoistedPackedTensor =
        scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
            ->getResult(0);
  } else {
    // If no loops were created, this is just hoisting without packing.
    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }

  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");

  // If the consumer of `padOp` was a `forOp`, propagate through iter args.
  scf::ForOp forOp = analysis.padConsumingForOp;
  if (forOp) {
    return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
                                 analysis.sliceOp, forOp);
  }

  // offsets = [maybe_leading_ivs, 0 .. 0].
  // sizes = [1 .. 1, transposedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  return rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedTensorType, hoistedPackedTensor, offsets,
      packingResult.sizes, packingResult.strides);
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<GenericOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
             DBGS() << " by " << numLoops << " loops\n");

  HoistPaddingAnalysis analysis(opToHoist, numLoops);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }

  // Construct the packing loop nest.
  IRMapping bvm;
  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
      rewriter, bvm, opToHoist, transposeVector, analysis);
  if (failed(packingResult)) {
    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
    return failure();
  }

  if (!transposeVector.empty())
    transposeOps.push_back(packingResult->maybeTransposeOp);

  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(), transposeVector);
  assert(succeeded(transposedTensorType) && "unexpected failure in type");

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  Value newResult =
      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
                             analysis, *packingResult);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  if (!transposeVector.empty()) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newResult.getDefiningOp());
    // Transpose the packed tensor back to the original storage order.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
    GenericOp unTransposeOp =
        makeTransposeOp(rewriter, loc, newResult, emptyTensor, transposeVector);
    newResult = unTransposeOp.getResult(0);
    transposeOps.push_back(unTransposeOp);
  }

  LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n");
  LLVM_DEBUG(
      DBGS() << "After hoisting: "
             << newResult.getDefiningOp()->getParentOfType<func::FuncOp>()
             << "\n");

  // Make the newly cloned `opToHoist` available to the caller.
  hoistedOp = packingResult->hoistedPadOp;

  LLVM_DEBUG(DBGS() << "--SUCCESS\n");
  return newResult;
}

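/// Convenience overload that creates an IRRewriter from the op's context and
/// forwards to the rewriter-based entry point above.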
FailureOr<Value>
mlir::linalg::hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
                                    ArrayRef<int64_t> transposeVector,
                                    tensor::PadOp &hoistedOp,
                                    SmallVectorImpl<GenericOp> &transposeOps) {
  IRRewriter rewriter(opToHoist.getContext());
  return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector,
                               hoistedOp, transposeOps);
}
