//===- HoistPadding.cpp - Hoisting for tensor::PadOp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::detail;

#ifndef NDEBUG
static bool debugPrintLoopInShortForm(Operation *op) {
  AsmState state(op->getParentOfType<func::FuncOp>());
  (void)state;
  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    forOp.getInductionVar().printAsOperand(dbgs(), state);
    dbgs() << " @ " << forOp.getOperation();
    return true;
  }
  return false;
}
#endif

static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
  LLVM_DEBUG(llvm::interleaveComma(backwardSlice, DBGS() << "--backwardSlice:",
                                   [](Operation *op) {
                                     dbgs() << "\n";
                                     DBGS() << "----";
                                     if (debugPrintLoopInShortForm(op)) {
                                       dbgs() << "\n";
                                       return;
                                     }
                                     dbgs() << *op << "\n";
                                   });
             DBGS() << "\n";);
}

/// Return at most nLevels of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}
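
// Example (added for illustration; the IR below is hypothetical): loops are
// collected innermost-first. Given a pad nested under three scf.for ops
//
//   scf.for %i {          // loop2
//     scf.for %j {        // loop1
//       scf.for %k {      // loop0
//         ... tensor.pad ...
//
// calling getAtMostNEnclosingLoops(padOp, /*nLevels=*/2, loops) yields
// loops == [loop0 (%k), loop1 (%j)], i.e. reverse (inner-to-outer) order,
// hence the name `reverseEnclosingLoops`.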

/// Return the immediately enclosing scf::ForOp loops, innermost first, up to
/// and including `untilLoop`.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
static void
getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (outermostEnclosingForOp != untilLoop &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

// Get all the ops in the backwards slice starting from `padOp` and that
// are dominated by the outermost enclosing loop.
// This also requires tracking ops defining values used in the region but
// defined above.
static void computeBackwardSlice(tensor::PadOp padOp,
                                 scf::ForOp outermostEnclosingForOp,
                                 SetVector<Operation *> &backwardSlice) {
  DominanceInfo domInfo(outermostEnclosingForOp);
  BackwardSliceOptions sliceOptions;
  sliceOptions.filter = [&](Operation *op) {
    return domInfo.dominates(outermostEnclosingForOp, op) &&
           !padOp->isProperAncestor(op);
  };
  sliceOptions.inclusive = true;

  // First, add the ops required to compute the region to the backwardSlice.
  SetVector<Value> valuesDefinedAbove;
  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
                            valuesDefinedAbove);
  for (Value v : valuesDefinedAbove) {
    LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
  }
  // Then, add the backward slice from padOp itself.
  LogicalResult result =
      getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
  assert(result.succeeded() && "expected a backward slice");
  (void)result;
}
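
// Example (added for illustration; the IR below is hypothetical): for
//
//   scf.for %i {                               // outermost enclosing loop
//     %ub = affine.min #map(%i)
//     %slice = tensor.extract_slice %source[%i][%ub][1]
//     %padded = tensor.pad %slice ... { linalg.yield %cst }
//
// the slice rooted at %padded contains the affine.min, the extract_slice and
// %padded itself (sliceOptions.inclusive is set), restricted to ops dominated
// by the outermost loop. Ops defined above that loop, such as the definitions
// of %source and %cst, fail the dominance filter and are left out.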

//===----------------------------------------------------------------------===//
// HoistPaddingAnalysis Implementation.
//===----------------------------------------------------------------------===//

namespace {
/// Analysis class to support tensor::PadOp hoisting across multiple enclosing
/// loops. The failure conditions are:
///   1. Pad op has a use that is not an input of a LinalgOp.
///   2. Pad op does not have a constant padding value.
///   3. There is no immediately enclosing scf::ForOp.
///   4. The backward slice from the pad op to the scf::ForOp to hoist above
///      contains an unknown op with non-index type operands, a region, or a
///      memory effect.
///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
///      empty.
///   6. The source tensor of pad op is not defined by an extract slice op.
///   7. The source tensor of the extract slice op is not defined outside of
///      the outermost enclosing scf::ForOp.
///   8. There is no enclosing scf::ForOp that indexes the padded data.
/// Other cases succeed and will trigger hoisting of the pad op.
struct HoistPaddingAnalysis {
  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);

  bool isValid() { return valid.has_value() && valid.value(); }
  bool isInvalid() { return valid.has_value() && !valid.value(); }

  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                 Location loc) const;

  /// Performs optional hoisting to enable hoist padding to occur. This may be
  /// necessary when `sliceOp` is not defined outside of the outermost enclosing
  /// loop we want to hoist above.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// // %source is available for packing here!
  /// scf.for %i
  ///   scf.for %j
  ///     scf.for %k
  ///       %slice = tensor.extract_slice %source [%i, %j]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  void enableHoistPadding(RewriterBase &rewriter);

  /// Common analysis builder to finalize the construction of the analysis once
  /// optional `enableHoistPadding` has run.
  /// `reverseEnclosingLoops.back()` is the loop to hoist above.
  void finalizeHoistPaddingAnalysis();

private:
  /// Encodes whether the analysis is valid and hoisting can proceed.
  std::optional<bool> valid;

  /// The padOp to hoist.
  tensor::PadOp opToHoist;

  /// Immediately enclosing loops considered for hoisting padding.
  SmallVector<scf::ForOp> reverseEnclosingLoops;

  /// Drop any non-index dependencies of `padOp` and `sliceOp` from
  /// `backwardSlice`. The method follows the use-def chains of the index
  /// operands consumed by `padOp` and `sliceOp` and drops the operations
  /// not part of this index computation. Afterwards, the filtered
  /// `backwardSlice` contains only the loops whose induction variable is
  /// used, directly or indirectly, to index the padded tensor. The method
  /// returns failure if the filtered backward slice contains an unexpected
  /// operation.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// scf.for %i
  ///   %unrelated = linalg.fill(%cst, %arg1)  // not used to index %source!
  ///   scf.for %j (%arg2 = %unrelated)
  ///     scf.for %k                           // not used to index %source!
  ///       %ubi = affine.min #map(%i)
  ///       %ubj = affine.min #map(%j)
  ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  /// dropNonIndexDependencies(%padded_slice, %slice)
  /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
  LogicalResult dropNonIndexDependencies();

public:
  /// The outermost loop, determined by `nLevels` above which `padOp` will
  /// be hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `padOp` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOps, nested immediately around `padOp`, such that:
  ///   1. they are nested under `outermostEnclosingForOp` (inclusive), and
  ///   2. their induction variable is used, directly or indirectly, in the
  ///      computation of `padOp`.
  /// The span of these loops determines the footprint of the packed tensor.
  SmallVector<scf::ForOp> packingLoops;

  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
  tensor::ExtractSliceOp sliceOp;

  /// If non-empty, this is the unique scf::ForOp that consumes the `sliceOp`.
  scf::ForOp padConsumingForOp;
};

} // namespace

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get at most `numLoops` of immediately enclosing loops.
  getAtMostNEnclosingLoops(opToHoist, numLoops, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  outermostEnclosingForOp = reverseEnclosingLoops.back();
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
                                           scf::ForOp outermostEnclosingForOp)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get enclosing loops until outermostEnclosingForOp.
  getEnclosingLoopsUntil(opToHoist, outermostEnclosingForOp,
                         reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  this->outermostEnclosingForOp = reverseEnclosingLoops.back();
  if (this->outermostEnclosingForOp != outermostEnclosingForOp) {
    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
  if (isInvalid())
    return;
  // If the padded data is not yet available before entering the outermost
  // enclosing loop, try to apply hoisting on this outermost loop.
  // TODO: we may want finer-grained hoisting of only that particular
  // `sliceOp`.
  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    outermostEnclosingForOp = cast<scf::ForOp>(
        hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
  }
}

void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
  if (isInvalid())
    return;

  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
                      << outermostEnclosingForOp << "\n"
                      << "--sliceOp: " << sliceOp << "\n"
                      << "--sliceOp.getSource(): " << sliceOp.getSource()
                      << "\n");
    LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
    valid = false;
    return;
  }
  if (sliceOp->hasOneUse()) {
    padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
  }

  // Check the region of `padOp` depends on a constant only. Adding hoisting
  // support for arbitrary padding regions would require cloning all
  // dependencies captured by the padding region.
  Value paddingValue = opToHoist.getConstantPaddingValue();
  if (!paddingValue ||
      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
    valid = false;
    return;
  }

  computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
  if (backwardSlice.size() <= 1) {
    valid = false;
    return;
  }

  debugPrintBackwardSlice(backwardSlice);
  // Remove all ops in the backward slice that are not used to index
  // the padded tensor. In particular, keep `padOp`, `sliceOp`, and
  // the loop and affine operations used for the index computation.
  if (failed(dropNonIndexDependencies())) {
    LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
    valid = false;
    return;
  }
  debugPrintBackwardSlice(backwardSlice);

  // Add only the loops part of the filtered `backwardSlice` to the
  // packing loops. All other loops are not used to index the padded
  // data and consequently access the same data in every loop
  // iteration. Adding them to the packing loops would increase the
  // cache footprint of the packed data by storing the same data
  // multiple times.
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
    if (backwardSlice.contains(forOp))
      packingLoops.push_back(forOp);

  // TODO: for multiple loops we need to track the use to the innermost loop.
  if (packingLoops.size() > 1 && padConsumingForOp) {
    LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
                         "Downgrade to 1 loop\n");
    packingLoops.resize(1);
  }

  // Note: at this point, packing loops may be empty but we would still like
  // to hoist the padding if so specified.

  // The analysis is valid and hoisting can occur.
  valid = true;
}

LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
  // Set of all values used for index computation.
  SetVector<Value> indexEdges;

  // Add all index operands of `operation` to `indexEdges`. An index operand
  // is an operand of type index.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // Check if any operation result is contained in `indexEdges`.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  // Starting from `opToHoist` and `sliceOp` walk the use-def edges of index
  // type in `backwardSlice`. Add the index operands of an operation to
  // `indexEdges` and remove all operations from `backwardSlice` that are not
  // part of the index computation.
  //
  // Example:
  // ```
  // %source = linalg.fill(%cst, %arg0)
  // scf.for %i
  //   %unrelated = linalg.fill(%cst, %arg1)  // not used to index %source!
  //   scf.for %j (%arg2 = %unrelated)
  //     scf.for %k                           // not used to index %source!
  //       %ubi = affine.min #map(%i)
  //       %ubj = affine.min #map(%j)
  //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  //       %padded_slice = tensor.pad %slice
  // ```
  // After iterating `backwardSlice` we obtain:
  // indexEdges = [%i, %j, %ubi, %ubj]
  // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add the index operands of `opToHoist` and `sliceOp` to start the
    // exploration of the index computation.
    if (op == opToHoist || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Add the index operands of the loop if its induction variable is
    // used for index computation.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Add the index operands of all other operations if at least one result
    // is used for index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      // Check the operands of the remaining operations all have index type.
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> Skip\n");
        return failure();
      }
      // Check the remaining operations do not have regions or memory effects.
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> Skip\n");
        return failure();
      }
      continue;
    }
    // Remove all other operations not used by the index computation. An
    // exception are constant operations that may be used by `opToHoist`.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}

SmallVector<Value>
HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                  Location loc) const {
  SmallVector<Value> dynamicTensorSizes;

  // Upper bound the packing loop lengths to size the packed tensor. Taking
  // upper bounds can make the sizes of the packed tensor independent of the
  // enclosing loops. This independence is a prerequisite for reusing the same
  // buffer for all enclosing loop iterations and hoisting its allocation out
  // of the enclosing loops.
  for (auto forOp : packingLoops) {
    // Compute an upper bound `ubVal` for the upper bound of `forOp`.
    FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
        rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
        /*stopCondition=*/
        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
          if (v == forOp.getUpperBound())
            return false;
          // Compute a bound that is independent of any affine op results.
          Operation *op = v.getDefiningOp();
          if (!op)
            return true;
          return !isa<affine::AffineMinOp, affine::AffineMaxOp,
                      affine::AffineApplyOp>(op);
        },
        /*closedUB=*/true);
    assert(succeeded(loopUb) && "could not get upper bound");
    Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *loopUb);

    // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
    // store the result to `dynamicTensorSizes`.
    // TODO: instead of using the lower bound of `forOp` directly, implement a
    // lower bound computation similar to the upper bound computation.
    AffineExpr lb, ub, step;
    bindDims(rewriter.getContext(), lb, ub);
    bindSymbols(rewriter.getContext(), step);
    Value res = rewriter.createOrFold<affine::AffineApplyOp>(
        loc, (ub - lb).ceilDiv(step),
        ValueRange{forOp.getLowerBound(), ubVal,
                   cast<scf::ForOp>(forOp).getStep()});
    dynamicTensorSizes.push_back(res);
  }

  return dynamicTensorSizes;
}
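
// Worked example (added for illustration; values are hypothetical): for a
// packing loop `scf.for %i = %lb to %ub step %c4` whose upper bound reifies
// to a loop-independent value %ub0, the packed-tensor size along that
// dimension is
//   affine.apply affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>
//       (%lb, %ub0)[%c4]
// e.g. lb = 0, ub0 = 30, step = 4 gives ceildiv(30, 4) = 8 packed slots.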

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
}

//===----------------------------------------------------------------------===//
// buildPackingLoopNest Implementation.
//===----------------------------------------------------------------------===//

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.getStep()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(),
        stepVal = forOp.getStep();
  auto loc = forOp->getLoc();
  return rewriter.createOrFold<affine::AffineApplyOp>(
      loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
}
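
// Worked example (added for illustration; values are hypothetical): for
// `scf.for %iv = %c2 to %ub step %c3`, the iteration count expression is
// (iv - 2) ceildiv 3, so iv = 2, 5, 8 map to iteration numbers 0, 1, 2.
// Because the lower bound and step are constants (or defined outside
// `outer`), the result depends only on %iv and can safely be used to index
// the hoisted packed tensor.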

// Build a packing loop nest by iteratively traversing the backward slice and
// clone the operations, iteratively stepping into the loops that we encounter.
// The implementation proceeds in a stack-like fashion:
//   1. Iteratively clone and step into the loops, pushing the
//      `hoistedPackedTensor` deeper in the stack.
//   2. At the innermost loop level, create a GenericOp if `transposeVector` is
//      non-empty.
//   3. At the innermost loop level, create a InsertSliceOp.
//   4. Iteratively pop and yield the result of the InsertSliceOp across the
//      cloned loops.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Step 0. Populate bvm with opToHoist.getSource if relevant.
  BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
  while (bbArg) {
    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
      break;
    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
  }

  // Step 1. iteratively clone loops and push `hoistedPackedTensor`.
  Value hoistedPackedTensor = emptyOp.getResult();
  OpBuilder::InsertionGuard g(rewriter);
  for (Operation *op : analysis.backwardSlice) {
    // Specifically sit out in the extract_slice(hoistedPackedTensor) case:
    // this is the piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
        LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
        continue;
      }
    }

    // Clone all operations except loops which require special handling.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      // We are at the right insertion point within the loop nest.
      rewriter.clone(*op, bvm);
      continue;
    }

    // Create a packing loop that takes `hoistedPackedTensor` as iteration
    // argument.
    auto clonedForOp = rewriter.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);

    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Do not insert guard here, we get deeper into the loop nest.
    rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(rewriter, outerLoop, clonedForOp);

    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Step 2. Construct offsets, sizes and strides for the innermost level of
  // the packing loop.
  int64_t nPackedLoops = clonedLoopIvs.size();
  // offsets = [clonedLoopIvs, 0 .. 0].
  offsets =
      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
                                leadingHoistedPackedTensorIndexings.end()};
  offsets.append(paddedRank, rewriter.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape].
  sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
  for (int64_t sz : transposedTensorType.getShape()) {
    // TODO: go grab dims when needed, atm tensor::PadOp yields a static
    // tensor.
    if (ShapedType::isDynamic(sz))
      return failure();
    sizes.push_back(rewriter.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank,
                                      rewriter.getIndexAttr(1));

  // Step 3. Optionally transpose the padded tensor.
  TransposeOp maybeTransposeOp;
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
        strides);
    maybeTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, paddedTensor, outputTensor, transposeVector);
    paddedTensor = maybeTransposeOp.getResult()[0];
  }

  // Innermost tensor.insert_slice and yields are optional / need loops.
  if (nPackedLoops > 0) {
    // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
    // optionally transposed padded slice into the packed tensor.
    Value inserted = rewriter.create<tensor::InsertSliceOp>(
        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);

    // Step 5. Iteratively pop the stack and propagate the yield.
    Value valueToYield = inserted;
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
      auto forOp = scf::getForInductionVarOwner(iv);
      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
      rewriter.create<scf::YieldOp>(loc, valueToYield);
      valueToYield = forOp.getResult(0);
    }
  }

  return PackingResult{
      offsets,
      sizes,
      strides,
      clonedLoopIvs,
      leadingHoistedPackedTensorIndexings,
      maybeTransposeOp,
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
  // Update actual number of loops, which may be smaller.
  int nPackedLoops = analysis.packingLoops.size();
  LLVM_DEBUG(DBGS() << "\n";
             DBGS() << "Func:\n"
                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();

  // Compute the type of the transposed padded tensor.
  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(paddedTensorType, transposeVector);
  if (failed(transposedTensorType)) {
    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
    return failure();
  }

  // Create the packed tensor<?x?x..? x transposedShape>.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
  // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
  llvm::append_range(packedShape, transposedTensorType->getShape());
  auto hoistedPackedTensorType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());

  // Set the insertion point right before the outer loop and start packing.
  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outerLoop);
  SmallVector<Value> dynamicTensorSizes =
      analysis.getHoistedPackedTensorSizes(rewriter, loc);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, hoistedPackedTensorType.getShape(),
      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);

  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  *transposedTensorType, emptyOp, analysis);
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
    RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }
  IRMapping bvm;
  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  analysis);
}

//===----------------------------------------------------------------------===//
// hoistPaddingOnTensors Implementation.
//===----------------------------------------------------------------------===//

/// Return true if we can walk back the use-def chain from `extractSliceOp` to
/// expectedSource going through DestinationStyleOpInterface inits only.
/// This is a poor man's analysis that is sufficient to check that the
/// extractSliceOp matches the tensor.pad we want to hoist.
/// In the future, it will be easier to ensure this with a matching symmetric
/// tensor.unpad op.
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
                    << "\n");
  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
  Value source = extractSliceOp.getSource();
  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
  while (source && source != expectedSource) {
    auto destOp =
        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
    if (!destOp)
      break;
    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
    source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
                 ->get();
  }
  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
  return source == expectedSource;
}

/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
/// iter arg), propagate the `hoistedPackedTensor` value through the same iter
/// arg.
/// TODO: for multiple loops we need to track the use to the innermost loop.
///
/// Match:
/// ```
///   %outerSliceOp = tensor.extract_slice ..
///   %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
///     %hoistedPackedTensor = tensor.pad %arg0
///     %1 = compute %hoistedPackedTensor
///     %2 = tensor.extract_slice %1
///     scf.yield %2
///   }
/// ```
///
/// and rewrite as:
/// ```
///   %outerSliceOp = tensor.extract_slice ..
///   %hoistedPackedTensor = tensor.pad %outerSliceOp
///   %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
///     %1 = compute %arg0
///     scf.yield %1
///   }
///   %2 = tensor.extract_slice %forOp
/// ```
///
/// Return null when no rewrite happened.
static tensor::ExtractSliceOp
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
                    << paddedValueBeforeHoisting << "\n");
  OpOperand *pUse = nullptr;
  for (OpOperand &use : outerSliceOp->getUses()) {
    if (use.getOwner() == forOp) {
      assert(!pUse && "Multiple slice uses in the for loop");
      pUse = &use;
    }
  }
  assert(pUse && "No slice use in the for loop");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());

  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
  if (!yieldingExtractSliceOp)
    return tensor::ExtractSliceOp();

  // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad.
  // In the future, it will be easier to ensure this with a matching symmetric
  // tensor.unpad op.
  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
                                 paddedValueBeforeHoisting))
    return tensor::ExtractSliceOp();

  SmallVector<Value> initArgs = forOp.getInitArgs();
  initArgs[iterArgNumber] = hoistedPackedTensor;
  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();

  int64_t numOriginalForOpResults = initArgs.size();
  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
                    << "\n");
  tensor::ExtractSliceOp extracted;
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(forOp);
    extracted = rewriter.create<tensor::ExtractSliceOp>(
        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
        outerSliceOp.getMixedStrides());
    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
  }
  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
        return yieldOperands;
      }));

  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                    << "\n");
  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
  LLVM_DEBUG(DBGS() << "with result #"
                    << numOriginalForOpResults + iterArgNumber
                    << " of forOp, giving us: " << extracted << "\n");
  rewriter.startOpModification(extracted);
  extracted.getSourceMutable().assign(
      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
  rewriter.finalizeOpModification(extracted);

  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                    << "\n");
  LLVM_DEBUG(DBGS() << "with region iter arg #"
                    << numOriginalForOpResults + iterArgNumber << "\n");
  rewriter.replaceAllUsesWith(
      paddedValueBeforeHoisting,
      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));

  return extracted;
}

/// Produce a tensor extracted from the packingResult. This can be used as a
/// replacement for `opToHoist` in callers.
static Value replaceByPackingResult(RewriterBase &rewriter,
                                    const IRMapping &bvm,
                                    tensor::PadOp opToHoist,
                                    RankedTensorType transposedTensorType,
                                    const HoistPaddingAnalysis &analysis,
                                    const PackingResult &packingResult) {
  // The replacement occurs under a single insertion point within the original
  // loop, just before opToHoist.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(opToHoist);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  int64_t nPackedLoops = packingResult.clonedLoopIvs.size();
  LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n");

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;

  Value hoistedPackedTensor;
  SmallVector<Value> loopIterationCounts;
  SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                    rewriter.getIndexAttr(0));
  if (nPackedLoops > 0) {
    loopIterationCounts =
        llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
          return buildLoopIterationCount(rewriter, outerLoop,
                                         cast<scf::ForOp>(loop));
        }));
    // Assert all loop iteration counts can be computed.
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
      llvm_unreachable("loop independence prerequisite not met");

    // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
    hoistedPackedTensor =
        scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
            ->getResult(0);
  } else {
    // If no loops were created, this is just hoisting without packing.
    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }

  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");

  // If the consumer of `padOp` was a `forOp`, propagate through iter args.
  scf::ForOp forOp = analysis.padConsumingForOp;
  if (forOp) {
    return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
                                 analysis.sliceOp, forOp);
  }

  // offsets = [maybe_leading_ivs, 0 .. 0].
  // sizes = [1 .. 1, transposedShape] (defined above).
  // strides = [1 .. 1] (defined above)
  return rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedTensorType, hoistedPackedTensor, offsets,
      packingResult.sizes, packingResult.strides);
}

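// Rough before/after sketch of the overall transformation (added for
// illustration; shapes, bounds and names below are hypothetical). Hoisting
// the pad above one loop (%k) packs all padded tiles into one tensor up
// front:
//
//   Before:
//     scf.for %k ... {
//       %slice  = tensor.extract_slice %source[%k] ... : ... to tensor<?x?xf32>
//       %padded = tensor.pad %slice ... : tensor<?x?xf32> to tensor<4x8xf32>
//       ... compute(%padded) ...
//     }
//
//   After (conceptually):
//     %init   = tensor.empty(%numIters) : tensor<?x4x8xf32>
//     %packed = scf.for %k ... iter_args(%t = %init) {
//       %iter   = affine.apply ((%k - lb) ceildiv step)
//       %slice  = tensor.extract_slice %source[%k] ...
//       %padded = tensor.pad %slice ...
//       %ins    = tensor.insert_slice %padded into %t[%iter, 0, 0] ...
//       scf.yield %ins : tensor<?x4x8xf32>
//     }
//     scf.for %k ... {
//       %iter   = affine.apply ((%k - lb) ceildiv step)
//       %padded = tensor.extract_slice %packed[%iter, 0, 0][1, 4, 8][1, 1, 1]
//                   : tensor<?x4x8xf32> to tensor<4x8xf32>
//       ... compute(%padded) ...
//     }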
FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
             DBGS() << " by " << numLoops << " loops\n");

  HoistPaddingAnalysis analysis(opToHoist, numLoops);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }

  /// Construct the packing loop nest.
  IRMapping bvm;
  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
      rewriter, bvm, opToHoist, transposeVector, analysis);
  if (failed(packingResult)) {
    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
    return failure();
  }

  if (!transposeVector.empty())
    transposeOps.push_back(packingResult->maybeTransposeOp);

  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(),
                                    transposeVector);
  assert(succeeded(transposedTensorType) && "unexpected failure in type");

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  Value newResult =
      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
                             analysis, *packingResult);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  if (!transposeVector.empty()) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newResult.getDefiningOp());
    // Transpose the packed tensor back to the original storage order.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
    TransposeOp unTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, newResult, emptyTensor, transposeVector);
    newResult = unTransposeOp.getResult()[0];
    transposeOps.push_back(unTransposeOp);
  }

  LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n");
  LLVM_DEBUG(
      DBGS() << "After hoisting: "
             << newResult.getDefiningOp()->getParentOfType<func::FuncOp>()
             << "\n");

  // Make the newly cloned `opToHoist` available to the caller.
  hoistedOp = packingResult->hoistedPadOp;

  LLVM_DEBUG(DBGS() << "--SUCCESS\n");
  return newResult;
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  IRRewriter rewriter(opToHoist.getContext());
  return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector,
                               hoistedOp, transposeOps);
}
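
// Usage sketch (added for illustration; `padOp`, the loop count and the
// replacement step are hypothetical and not part of this file). A caller
// typically obtains a tensor::PadOp and invokes the hoisting entry point,
// collecting any transposes that were created:
//
//   IRRewriter rewriter(padOp.getContext());
//   tensor::PadOp hoistedOp;
//   SmallVector<linalg::TransposeOp> transposeOps;
//   FailureOr<Value> replacement = linalg::hoistPaddingOnTensors(
//       rewriter, padOp, /*numLoops=*/2, /*transposeVector=*/{}, hoistedOp,
//       transposeOps);
//   if (succeeded(replacement))
//     rewriter.replaceOp(padOp, *replacement);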