1 | //===- HoistPadding.cpp - Hoisting for tensor::PadOp ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements functions concerned with hoisting padding operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Analysis/Presburger/IntegerRelation.h" |
14 | #include "mlir/Analysis/SliceAnalysis.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
20 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
21 | #include "mlir/Dialect/SCF/IR/SCF.h" |
22 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
24 | #include "mlir/IR/AsmState.h" |
25 | #include "mlir/IR/Dominance.h" |
26 | #include "mlir/IR/Matchers.h" |
27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
28 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
29 | #include "mlir/Transforms/RegionUtils.h" |
30 | #include "llvm/Support/Debug.h" |
31 | |
32 | using llvm::dbgs; |
33 | |
34 | #define DEBUG_TYPE "hoist-padding" |
35 | |
36 | #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") |
37 | |
38 | using namespace mlir; |
39 | using namespace mlir::linalg; |
40 | using namespace mlir::linalg::detail; |
41 | |
42 | #ifndef NDEBUG |
43 | static bool debugPrintLoopInShortForm(Operation *op) { |
44 | AsmState state(op->getParentOfType<func::FuncOp>()); |
45 | (void)state; |
46 | if (auto forOp = dyn_cast<scf::ForOp>(op)) { |
47 | forOp.getInductionVar().printAsOperand(dbgs(), state); |
48 | dbgs() << " @ " << forOp.getOperation(); |
49 | return true; |
50 | } |
51 | return false; |
52 | } |
53 | #endif |
54 | |
55 | static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) { |
56 | LLVM_DEBUG(llvm::interleaveComma(backwardSlice, DBGS() << "--backwardSlice:" , |
57 | [](Operation *op) { |
58 | dbgs() << "\n" ; |
59 | DBGS() << "----" ; |
60 | if (debugPrintLoopInShortForm(op)) { |
61 | dbgs() << "\n" ; |
62 | return; |
63 | } |
64 | dbgs() << *op << "\n" ; |
65 | }); |
66 | DBGS() << "\n" ;); |
67 | } |
68 | |
69 | /// Return at most nLevels of immediately enclosing scf::ForOp loops. |
70 | /// Stops at the first parent that is not an scf::ForOp. |
71 | /// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm. |
/// Control-flow and other containing ops with regions are not modeled atm.
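/// For example (an illustrative sketch): with `nLevels = 2` and
/// ```
///   scf.for %i
///     scf.for %j
///       %padded = tensor.pad ...
/// ```
/// the result is [loop %j, loop %i], i.e. innermost loop first.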
73 | static void |
74 | getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels, |
75 | SmallVector<scf::ForOp> &reverseEnclosingLoops) { |
76 | scf::ForOp outermostEnclosingForOp = nullptr; |
77 | Operation *nextEnclosingOp = padOp->getParentOp(); |
78 | while (nLevels-- > 0 && |
79 | (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) { |
80 | LLVM_DEBUG(DBGS() << "loops: " ; |
81 | debugPrintLoopInShortForm(outermostEnclosingForOp); |
82 | dbgs() << "\n" ); |
83 | reverseEnclosingLoops.push_back(outermostEnclosingForOp); |
84 | nextEnclosingOp = outermostEnclosingForOp->getParentOp(); |
85 | } |
86 | } |
87 | |
/// Return the immediately enclosing scf::ForOp loops, innermost first, up to
/// and including `untilLoop`.
89 | /// Stops at the first parent that is not an scf::ForOp. |
90 | /// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm. |
91 | /// Control-flow and other containing ops with regions are not modeled atm. |
92 | static void |
93 | getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop, |
94 | SmallVector<scf::ForOp> &reverseEnclosingLoops) { |
95 | scf::ForOp outermostEnclosingForOp = nullptr; |
96 | Operation *nextEnclosingOp = padOp->getParentOp(); |
97 | while (outermostEnclosingForOp != untilLoop && |
98 | (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) { |
99 | LLVM_DEBUG(DBGS() << "loops: " ; |
100 | debugPrintLoopInShortForm(outermostEnclosingForOp); |
101 | dbgs() << "\n" ); |
102 | reverseEnclosingLoops.push_back(outermostEnclosingForOp); |
103 | nextEnclosingOp = outermostEnclosingForOp->getParentOp(); |
104 | } |
105 | } |
106 | |
// Get all the ops in the backward slice starting from `padOp` that are
// dominated by the outermost enclosing loop.
109 | // This also requires tracking ops defining values used in the region but |
110 | // defined above. |
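// For example (an illustrative sketch): if the padding region yields a value
// %pv computed inside the enclosing loops,
// ```
//   scf.for %i
//     %pv = ... %i ...
//     %padded = tensor.pad %slice ... { linalg.yield %pv }
// ```
// then %pv is used in the region but defined above it, and its backward slice
// (restricted by the same domination filter) must be added as well.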
111 | static void computeBackwardSlice(tensor::PadOp padOp, |
112 | scf::ForOp outermostEnclosingForOp, |
113 | SetVector<Operation *> &backwardSlice) { |
114 | DominanceInfo domInfo(outermostEnclosingForOp); |
115 | BackwardSliceOptions sliceOptions; |
116 | sliceOptions.filter = [&](Operation *op) { |
117 | return domInfo.dominates(outermostEnclosingForOp, op) && |
118 | !padOp->isProperAncestor(op); |
119 | }; |
120 | sliceOptions.inclusive = true; |
121 | |
122 | // First, add the ops required to compute the region to the backwardSlice. |
123 | SetVector<Value> valuesDefinedAbove; |
124 | getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(), |
125 | valuesDefinedAbove); |
126 | for (Value v : valuesDefinedAbove) { |
    getBackwardSlice(v, &backwardSlice, sliceOptions);
128 | } |
129 | // Then, add the backward slice from padOp itself. |
130 | getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions); |
131 | } |
132 | |
133 | //===----------------------------------------------------------------------===// |
134 | // HoistPaddingAnalysis Implementation. |
135 | //===----------------------------------------------------------------------===// |
136 | |
137 | namespace { |
138 | /// Analysis class to support tensor::PadOp hoisting across multiple enclosing |
139 | /// loops. The failure conditions are: |
140 | /// 1. Pad op has a use that is not an input of a LinalgOp. |
141 | /// 2. Pad op does not have a constant padding value. |
142 | /// 3. There is no immediately enclosing scf::ForOp. |
143 | /// 4. The backward slice from the pad op to the scf::ForOp to hoist above |
144 | /// contains an unknown op with non index type operands, a region, or a |
145 | /// memory effect. |
146 | /// 5. The backward slice from the pad op to the scf::ForOp to hoist above is |
147 | /// empty. |
148 | /// 6. The source tensor of pad op is not defined by an extract slice op. |
149 | /// 7. The source tensor of the extract slice op is not defined outside of |
150 | /// the outermost enclosing scf::ForOp. |
151 | /// 8. There is no enclosing scf::ForOp that indexes the padded data. |
152 | /// Other cases succeed and will trigger hoisting of the pad op. |
153 | struct HoistPaddingAnalysis { |
154 | HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops); |
155 | HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp); |
156 | |
157 | bool isValid() { return valid.has_value() && valid.value(); } |
158 | bool isInvalid() { return valid.has_value() && !valid.value(); } |
159 | |
160 | /// Footprint of the hoistedPackedTensor, computed from the packingLoops. |
161 | SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter, |
162 | Location loc) const; |
163 | |
164 | /// Performs optional hoisting to enable hoist padding to occur. This may be |
165 | /// necessary when `sliceOp` is not defined outside of the outermost enclosing |
166 | /// loop we want to hoist above. |
167 | /// |
168 | /// Example: |
169 | /// ``` |
170 | /// %source = linalg.fill(%cst, %arg0) |
171 | /// // %source is available for packing here! |
172 | /// scf.for %i |
173 | /// scf.for %j |
174 | /// scf.for %k |
175 | /// %slice = tensor.extract_slice %source [%i, %j] |
176 | /// %padded_slice = tensor.pad %slice |
177 | /// ``` |
178 | void enableHoistPadding(RewriterBase &rewriter); |
179 | |
180 | /// Common analysis builder to finalize the construction of the analysis once |
181 | /// optional `enableHoistPadding` has run. |
182 | /// `reverseEnclosingLoops.back()` is the loop to hoist above. |
183 | void finalizeHoistPaddingAnalysis(); |
184 | |
185 | private: |
186 | /// Encodes whether the analysis is valid and hoisting can proceed. |
187 | std::optional<bool> valid; |
188 | |
189 | /// The padOp to hoist. |
190 | tensor::PadOp opToHoist; |
191 | |
192 | /// Immediately enclosing loops considered for hoisting padding. |
193 | SmallVector<scf::ForOp> reverseEnclosingLoops; |
194 | |
195 | /// Drop any non-index dependencies of `padOp` and `sliceOp` from |
196 | /// `backwardSlice`. The method follows the use-def chains of the index |
197 | /// operands consumed by `padOp` and `sliceOp` and drops the operations |
198 | /// not part of this index computation. Afterwards, the filtered |
199 | /// `backwardSlice` contains only the loops whose induction variable is |
200 | /// used, directly or indirectly, to index the padded tensor. The method |
201 | /// returns failure if the filtered backward slice contains an unexpected |
202 | /// operation. |
203 | /// |
204 | /// Example: |
205 | /// ``` |
206 | /// %source = linalg.fill(%cst, %arg0) |
207 | /// scf.for %i |
  ///     %unrelated = linalg.fill(%cst, %arg1)  // not used to index %source!
  ///     scf.for %j (%arg2 = %unrelated)
  ///       scf.for %k                           // not used to index %source!
212 | /// %ubi = affine.min #map(%i) |
213 | /// %ubj = affine.min #map(%j) |
214 | /// %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj] |
215 | /// %padded_slice = tensor.pad %slice |
216 | /// ``` |
217 | /// dropNonIndexDependencies(%padded_slice, %slice) |
218 | /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice. |
219 | LogicalResult dropNonIndexDependencies(); |
220 | |
221 | public: |
222 | /// The outermost loop, determined by `nLevels` above which `padOp` will |
223 | /// be hoisted. |
224 | scf::ForOp outermostEnclosingForOp; |
225 | |
226 | /// Backward slice rooted at `padOp` and nested under |
227 | /// `outermostEnclosingForOp`. |
228 | SetVector<Operation *> backwardSlice; |
229 | |
  /// The scf::ForOps immediately enclosing `padOp` such that:
  ///  1. they are nested under `outermostEnclosingForOp` (inclusive) and
  ///  2. their induction variables are used, directly or indirectly, in the
  ///     computation of `padOp`.
234 | /// The span of these loops determines the footprint of the packed tensor. |
235 | SmallVector<scf::ForOp> packingLoops; |
236 | |
237 | /// The ExtractSliceOp that feeds the PadOp we want to hoist. |
238 | tensor::ExtractSliceOp sliceOp; |
239 | |
  /// If non-null, the unique scf::ForOp that consumes the `sliceOp`.
241 | scf::ForOp padConsumingForOp; |
242 | }; |
243 | |
244 | } // namespace |
245 | |
246 | HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops) |
247 | : valid(std::nullopt), opToHoist(padOp) { |
248 | // Get at most `numLoops` of immediately enclosing loops. |
249 | getAtMostNEnclosingLoops(opToHoist, numLoops, reverseEnclosingLoops); |
250 | if (reverseEnclosingLoops.empty()) { |
251 | LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n" ); |
252 | valid = false; |
253 | return; |
254 | } |
255 | outermostEnclosingForOp = reverseEnclosingLoops.back(); |
256 | sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>(); |
257 | if (!sliceOp) { |
258 | LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n" ); |
259 | valid = false; |
260 | return; |
261 | } |
262 | } |
263 | |
264 | HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, |
265 | scf::ForOp outermostEnclosingForOp) |
266 | : valid(std::nullopt), opToHoist(padOp) { |
267 | // Get enclosing loops until outermostEnclosingForOp. |
268 | getEnclosingLoopsUntil(opToHoist, outermostEnclosingForOp, |
269 | reverseEnclosingLoops); |
270 | if (reverseEnclosingLoops.empty()) { |
271 | LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n" ); |
272 | valid = false; |
273 | return; |
274 | } |
275 | this->outermostEnclosingForOp = reverseEnclosingLoops.back(); |
276 | if (this->outermostEnclosingForOp != outermostEnclosingForOp) { |
277 | LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n" ); |
278 | valid = false; |
279 | return; |
280 | } |
281 | sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>(); |
282 | if (!sliceOp) { |
283 | LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n" ); |
284 | valid = false; |
285 | return; |
286 | } |
287 | } |
288 | |
289 | void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) { |
290 | if (isInvalid()) |
291 | return; |
292 | // If the padded data is not yet available before entering the outermost |
293 | // enclosing loop, try to apply hoisting on this outermost loop. |
294 | // TODO: we may want finer-grained hoisting of only that particular `sliceOp`. |
295 | if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) { |
296 | outermostEnclosingForOp = cast<scf::ForOp>( |
297 | hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp)); |
298 | } |
299 | } |
300 | |
301 | void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() { |
302 | if (isInvalid()) |
303 | return; |
304 | |
305 | if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) { |
306 | LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n" |
307 | << outermostEnclosingForOp << "\n" |
308 | << "--sliceOp: " << sliceOp << "\n" |
309 | << "--sliceOp.getSource(): " << sliceOp.getSource() |
310 | << "\n" ); |
311 | LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n" ); |
312 | valid = false; |
313 | return; |
314 | } |
315 | if (sliceOp->hasOneUse()) { |
316 | padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin())); |
317 | } |
318 | |
319 | // Check the region of `padOp` depends on a constant only. Adding hoisting |
320 | // support for arbitrary padding regions would require cloning all |
321 | // dependencies captured by the padding region. |
322 | Value paddingValue = opToHoist.getConstantPaddingValue(); |
323 | if (!paddingValue || |
324 | !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) { |
325 | LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n" ); |
326 | valid = false; |
327 | return; |
328 | } |
329 | |
330 | computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice); |
331 | if (backwardSlice.size() <= 1) { |
332 | valid = false; |
333 | return; |
334 | } |
335 | |
336 | debugPrintBackwardSlice(backwardSlice); |
337 | // Remove all ops in the backward slice that are not used to index |
338 | // the padded tensor. In particular, keep `padOp`, `sliceOp`, and |
339 | // the loop and affine operations used for the index computation. |
  if (failed(dropNonIndexDependencies())) {
341 | LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n" ); |
342 | valid = false; |
343 | return; |
344 | } |
345 | debugPrintBackwardSlice(backwardSlice); |
346 | |
347 | // Add only the loops part of the filtered `backwardSlice` to the |
348 | // packing loops. All other loops are not used to index the padded |
349 | // data and consequently access the same data in every loop |
350 | // iteration. Adding them to the packing loops would increase the |
351 | // cache footprint of the packed data by storing the same data |
352 | // multiple times. |
353 | for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops)) |
354 | if (backwardSlice.contains(forOp)) |
355 | packingLoops.push_back(forOp); |
356 | |
357 | // TODO: for multiple loops we need to track the use to the innermost loop. |
358 | if (packingLoops.size() > 1 && padConsumingForOp) { |
359 | LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> " |
360 | "Downgrade to 1 loop\n" ); |
361 | packingLoops.resize(1); |
362 | } |
363 | |
364 | // Note: at this point, packing loops may be empty but we would still like |
365 | // to hoist the padding if so specified. |
366 | |
367 | // The analysis is valid and hoisting can occur. |
368 | valid = true; |
369 | } |
370 | |
371 | LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() { |
372 | // Set of all values used for index computation. |
373 | SetVector<Value> indexEdges; |
374 | |
375 | // Add all index operands of `operation` to `indexEdges`. An index operand |
376 | // is an operand of type index. |
377 | auto addIndexOperandsToIndexEdges = [&](Operation *operation) { |
378 | for (Value operand : operation->getOperands()) |
379 | if (operand.getType().isIndex()) |
        indexEdges.insert(operand);
381 | }; |
382 | |
383 | // Check if any operation result is contained in `indexEdges`. |
384 | auto hasIndexResult = [&](Operation *operation) { |
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
387 | }); |
388 | }; |
389 | |
390 | // Starting from `opToHoist` and `sliceOp` walk the use-def edges of index |
391 | // type in `backwardSlice`. Add the index operands of an operation to |
392 | // `indexEdges` and remove all operations from `backwardSlice` that are not |
393 | // part of the index computation. |
394 | // |
395 | // Example: |
396 | // ``` |
397 | // %source = linalg.fill(%cst, %arg0) |
398 | // scf.for %i |
399 | // %unrelated = linalg.fill(%cst, %arg1) // not used to index %source! |
400 | // scf.for %j (%arg2 = %unrelated) |
401 | // scf.for %k // not used to index %source! |
402 | // %ubi = affine.min #map(%i) |
403 | // %ubj = affine.min #map(%j) |
404 | // %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj] |
405 | // %padded_slice = tensor.pad %slice |
406 | // ``` |
407 | // After iterating `backwardSlice` we obtain: |
408 | // indexEdges = [%i, %j, %ubi, %ubj] |
409 | // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k] |
410 | SetVector<Operation *> operationsToRemove; |
  for (Operation *op : llvm::reverse(backwardSlice)) {
412 | // Add the index operands of `opToHoist` and `sliceOp` to start the |
413 | // exploration of the index computation. |
414 | if (op == opToHoist || op == sliceOp) { |
415 | addIndexOperandsToIndexEdges(op); |
416 | continue; |
417 | } |
418 | // Add the index operands of the loop if its induction variable is |
419 | // used for index computation. |
420 | if (auto forOp = dyn_cast<scf::ForOp>(op)) { |
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
422 | addIndexOperandsToIndexEdges(op); |
423 | continue; |
424 | } |
425 | } |
426 | // Add the index operands of all other operations if at least one result |
427 | // is used for index computation. |
428 | if (hasIndexResult(op)) { |
429 | addIndexOperandsToIndexEdges(op); |
430 | // Check the operands of the remaining operations all have index type. |
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
433 | LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: " |
434 | << op << " -> Skip\n" ); |
435 | return failure(); |
436 | } |
437 | // Check the remaining operations do not have regions or memory effects. |
438 | auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op); |
439 | bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect(); |
440 | if (hasMemoryEffect || op->getNumRegions() != 0) { |
441 | LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: " |
442 | << op << " -> Skip\n" ); |
443 | return failure(); |
444 | } |
445 | continue; |
446 | } |
    // Remove all other operations not used by the index computation. An
    // exception is made for constant operations that may be used by `opToHoist`.
449 | if (!isa<arith::ConstantOp>(op)) |
      operationsToRemove.insert(op);
451 | } |
452 | backwardSlice.set_subtract(operationsToRemove); |
453 | return success(); |
454 | } |
455 | |
456 | SmallVector<Value> |
457 | HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter, |
458 | Location loc) const { |
459 | SmallVector<Value> dynamicTensorSizes; |
460 | |
461 | // Upper bound the packing loop lengths to size the packed tensor. Taking |
462 | // upper bounds can make the sizes of the packed tensor independent of the |
463 | // enclosing loops. This independence is a prerequisite for reusing the same |
464 | // buffer for all enclosing loop iterations and hoisting its allocation out |
465 | // of the enclosing loops. |
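  // For example (an illustrative sketch): for a packing loop
  //   scf.for %i = %c0 to %ub step %c4   with %ub = affine.min #map(%j)
  // the reified bound replaces %ub with a value independent of %j, and the
  // packed dimension is sized as (%ubBound - %c0).ceilDiv(%c4), which no
  // longer varies across iterations of the enclosing loops.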
466 | for (auto forOp : packingLoops) { |
467 | // Compute an upper bound `ubVal` for the upper bound of `forOp`. |
468 | FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound( |
469 | rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(), |
470 | /*stopCondition=*/ |
471 | [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { |
472 | if (v == forOp.getUpperBound()) |
473 | return false; |
474 | // Compute a bound that is independent of any affine op results. |
475 | Operation *op = v.getDefiningOp(); |
476 | if (!op) |
477 | return true; |
478 | return !isa<affine::AffineMinOp, affine::AffineMaxOp, |
479 | affine::AffineApplyOp>(op); |
480 | }, |
481 | /*closedUB=*/true); |
482 | assert(succeeded(loopUb) && "could not get upper bound" ); |
483 | Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *loopUb); |
484 | |
485 | // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and |
486 | // store the result to `dynamicTensorSizes`. |
487 | // TODO: instead of using the lower bound of `forOp` directly, implement a |
488 | // lower bound computation similar to the upper bound computation. |
489 | AffineExpr lb, ub, step; |
490 | bindDims(rewriter.getContext(), lb, ub); |
491 | bindSymbols(rewriter.getContext(), step); |
492 | Value res = rewriter.createOrFold<affine::AffineApplyOp>( |
493 | loc, (ub - lb).ceilDiv(step), |
494 | ValueRange{forOp.getLowerBound(), ubVal, |
495 | cast<scf::ForOp>(forOp).getStep()}); |
496 | dynamicTensorSizes.push_back(res); |
497 | } |
498 | |
499 | return dynamicTensorSizes; |
500 | } |
501 | |
502 | static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) { |
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
504 | } |
505 | |
506 | //===----------------------------------------------------------------------===// |
507 | // buildPackingLoopNest Implementation. |
508 | //===----------------------------------------------------------------------===// |
509 | |
510 | /// Return the current iteration number in the loop (iv - lb).ceilDiv(step). |
511 | /// The returned Value is guaranteed not to depend on any loop comprised in |
512 | /// [`outer`, `forOp`]. |
513 | /// Return null if such a loop-independent quantity cannot be computed. |
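/// For example (an illustrative sketch): for
/// ```
///   scf.for %i = %c0 to %n step %c2
/// ```
/// this builds an `affine.apply` computing `(%i - %c0) ceildiv %c2`, i.e. the
/// zero-based number of the current iteration.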
514 | static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer, |
515 | scf::ForOp forOp) { |
516 | MLIRContext *ctx = forOp->getContext(); |
517 | AffineExpr iv, lb, step; |
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
520 | if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) || |
521 | !isDefinedOutsideOrConstant(outer, forOp.getStep())) |
522 | return Value(); |
523 | Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(), |
524 | stepVal = forOp.getStep(); |
525 | auto loc = forOp->getLoc(); |
526 | return rewriter.createOrFold<affine::AffineApplyOp>( |
      loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
528 | } |
529 | |
530 | // Build a packing loop nest by iteratively traversing the backward slice and |
531 | // clone the operations, iteratively stepping into the loops that we encounter. |
532 | // The implementation proceeds in a stack-like fashion: |
//   1. Iteratively clone and step into the loops, pushing the
//      `hoistedPackedTensor` deeper in the stack.
536 | // 2. At the innermost loop level, create a GenericOp if `transposeVector` is |
537 | // non-empty. |
538 | // 3. At the innermost loop level, create a InsertSliceOp. |
539 | // 4. Iteratively pop and yield the result of the InsertSliceOp across the |
540 | // cloned loops. |
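// A rough sketch of the resulting IR for a single packing loop %i (simplified
// IR, illustrative names only):
// ```
//   %empty = tensor.empty(%size) : tensor<?x4x8xf32>
//   %packed = scf.for %i = %lb to %ub step %s
//       iter_args(%packedIter = %empty) -> tensor<?x4x8xf32> {
//     %count = affine.apply (%i - %lb) ceildiv %s
//     %slice = tensor.extract_slice %source[...]
//     %padded = tensor.pad %slice ...
//     %inserted = tensor.insert_slice %padded into
//         %packedIter[%count, 0, 0] [1, 4, 8] [1, 1, 1]
//     scf.yield %inserted : tensor<?x4x8xf32>
//   }
// ```
// The original loop nest then extracts 1x4x8 slices from %packed instead of
// re-padding on every iteration.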
541 | static FailureOr<PackingResult> buildPackingLoopNestImpl( |
542 | RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist, |
543 | ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType, |
544 | tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) { |
545 | SmallVector<OpFoldResult> offsets, sizes, strides; |
546 | SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings; |
547 | |
548 | scf::ForOp outerLoop = analysis.outermostEnclosingForOp; |
549 | |
550 | Location loc = opToHoist->getLoc(); |
551 | RankedTensorType paddedTensorType = opToHoist.getResultType(); |
552 | int paddedRank = paddedTensorType.getRank(); |
553 | |
554 | // Step 0. Populate bvm with opToHoist.getSource if relevant. |
555 | BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource()); |
556 | while (bbArg) { |
557 | auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp()); |
558 | if (!forOp) |
559 | break; |
560 | if (forOp != outerLoop && !outerLoop->isAncestor(forOp)) |
561 | break; |
562 | OpOperand &operand = *forOp.getTiedLoopInit(bbArg); |
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
565 | } |
566 | |
567 | // Step 1. iteratively clone loops and push `hoistedPackedTensor`. |
568 | Value hoistedPackedTensor = emptyOp.getResult(); |
569 | OpBuilder::InsertionGuard g(rewriter); |
570 | for (Operation *op : analysis.backwardSlice) { |
571 | // Specifically sit out in the extract_slice(hoistedPackedTensor) case: this |
572 | // is the piece we seek to replace. |
573 | if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { |
574 | if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) { |
575 | LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n" ); |
576 | continue; |
577 | } |
578 | } |
579 | |
580 | // Clone all operations except loops which require special handling. |
581 | auto forOp = dyn_cast<scf::ForOp>(op); |
582 | if (!forOp) { |
583 | // We are at the right insertion point within the loop nest. |
      rewriter.clone(*op, bvm);
585 | continue; |
586 | } |
587 | |
588 | // Create a packing loop that takes `hoistedPackedTensor` as iteration |
589 | // argument. |
590 | auto clonedForOp = rewriter.create<scf::ForOp>( |
591 | loc, bvm.lookupOrDefault(forOp.getLowerBound()), |
592 | bvm.lookupOrDefault(forOp.getUpperBound()), |
593 | bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor); |
594 | |
595 | // Map the induction var, region args and results to the `clonedForOp`. |
596 | bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar()); |
597 | bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs()); |
598 | bvm.map(forOp.getResults(), clonedForOp.getResults()); |
599 | assert(clonedForOp->getNumRegions() == 1); |
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());
601 | |
602 | // Do not insert guard here, we get deeper into the loop nest. |
603 | rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front()); |
604 | Value loopIndependentIterationCount = |
605 | buildLoopIterationCount(rewriter, outerLoop, clonedForOp); |
606 | |
607 | // Assert the loop-independent iteration count can be computed. |
608 | if (!loopIndependentIterationCount) |
609 | llvm_unreachable("loop independence prerequisite not met" ); |
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
612 | hoistedPackedTensor = clonedForOp.getRegionIterArgs().front(); |
613 | } |
614 | |
615 | // Step 2. Construct offsets, sizes and strides for the innermost level of the |
616 | // packing loop. |
617 | int64_t nPackedLoops = clonedLoopIvs.size(); |
618 | // offsets = [clonedLoopIvs, 0 .. 0]. |
619 | offsets = |
620 | SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(), |
621 | leadingHoistedPackedTensorIndexings.end()}; |
622 | offsets.append(paddedRank, rewriter.getIndexAttr(0)); |
623 | // sizes = [1 .. 1, transposedShape]. |
624 | sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1)); |
625 | for (int64_t sz : transposedTensorType.getShape()) { |
626 | // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor. |
627 | if (ShapedType::isDynamic(sz)) |
628 | return failure(); |
629 | sizes.push_back(rewriter.getIndexAttr(sz)); |
630 | } |
631 | // strides = [1 .. 1]. |
632 | strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank, |
633 | rewriter.getIndexAttr(1)); |
634 | |
635 | // Step 3. Optionally transpose the padded tensor. |
636 | GenericOp maybeTransposeOp; |
637 | Value paddedTensor = bvm.lookup(opToHoist.getResult()); |
638 | if (!transposeVector.empty()) { |
639 | Value outputTensor = rewriter.create<tensor::ExtractSliceOp>( |
640 | loc, transposedTensorType, hoistedPackedTensor, offsets, sizes, |
641 | strides); |
642 | maybeTransposeOp = makeTransposeOp(rewriter, loc, paddedTensor, |
643 | outputTensor, transposeVector); |
644 | paddedTensor = maybeTransposeOp.getResult(0); |
645 | } |
646 | |
647 | // Innermost tensor.insert_slice and yields are optional / need loops. |
648 | if (nPackedLoops > 0) { |
649 | // Step 4. Create InsertSliceOp at the innermost loop level, inserting an |
650 | // optionally transposed padded slice into the packed tensor. |
651 | Value inserted = rewriter.create<tensor::InsertSliceOp>( |
652 | loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides); |
653 | |
654 | // Step 5. Iteratively pop the stack and propagate the yield. |
655 | Value valueToYield = inserted; |
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
657 | auto forOp = scf::getForInductionVarOwner(iv); |
658 | rewriter.setInsertionPointToEnd(&forOp.getRegion().front()); |
659 | rewriter.create<scf::YieldOp>(loc, valueToYield); |
660 | valueToYield = forOp.getResult(0); |
661 | } |
662 | } |
663 | |
664 | return PackingResult{ |
665 | offsets, |
666 | sizes, |
667 | strides, |
668 | clonedLoopIvs, |
669 | leadingHoistedPackedTensorIndexings, |
670 | maybeTransposeOp, |
671 | cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())}; |
672 | } |
673 | |
674 | /// Build the packing loop nest required to hoist `opToHoist` above |
675 | /// `outermostEnclosingForOp`. |
676 | /// The loop nest is built just before `outermostEnclosingForOp`. |
677 | static FailureOr<PackingResult> buildPackingLoopNestImpl( |
678 | RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist, |
679 | ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) { |
680 | // Update actual number of loops, which may be smaller. |
681 | int nPackedLoops = analysis.packingLoops.size(); |
682 | LLVM_DEBUG(DBGS() << "\n" ; |
683 | DBGS() << "Func:\n" |
684 | << *opToHoist->getParentOfType<func::FuncOp>() << "\n" ; |
685 | DBGS() << "Start hoisting above " << nPackedLoops << " loops\n" ); |
686 | |
687 | Location loc = opToHoist->getLoc(); |
688 | RankedTensorType paddedTensorType = opToHoist.getResultType(); |
689 | |
690 | // Compute the type of the transposed padded tensor. |
691 | FailureOr<RankedTensorType> transposedTensorType = |
692 | tensor::computeTransposedType(paddedTensorType, transposeVector); |
693 | if (failed(transposedTensorType)) { |
694 | LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n" ); |
695 | return failure(); |
696 | } |
697 | |
698 | // Create the packed tensor<?x?x..? x transposedShape>. |
699 | SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic); |
700 | // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor. |
701 | llvm::append_range(packedShape, transposedTensorType->getShape()); |
702 | auto hoistedPackedTensorType = RankedTensorType::get( |
703 | packedShape, transposedTensorType->getElementType()); |
704 | |
705 | // Set the insertion point right before the outer loop and start packing. |
706 | scf::ForOp outerLoop = analysis.outermostEnclosingForOp; |
707 | OpBuilder::InsertionGuard g(rewriter); |
708 | rewriter.setInsertionPoint(outerLoop); |
709 | SmallVector<Value> dynamicTensorSizes = |
710 | analysis.getHoistedPackedTensorSizes(rewriter, loc); |
711 | auto emptyOp = rewriter.create<tensor::EmptyOp>( |
712 | loc, hoistedPackedTensorType.getShape(), |
713 | hoistedPackedTensorType.getElementType(), dynamicTensorSizes); |
714 | |
715 | return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector, |
716 | *transposedTensorType, emptyOp, analysis); |
717 | } |
718 | |
719 | /// Build the packing loop nest required to hoist `opToHoist` above |
720 | /// `outermostEnclosingForOp`. |
721 | /// The loop nest is built just before `outermostEnclosingForOp`. |
722 | FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest( |
723 | RewriterBase &rewriter, tensor::PadOp opToHoist, |
724 | scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) { |
725 | HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp); |
726 | analysis.enableHoistPadding(rewriter); |
727 | analysis.finalizeHoistPaddingAnalysis(); |
728 | if (!analysis.isValid()) { |
729 | LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n" ); |
730 | return failure(); |
731 | } |
732 | IRMapping bvm; |
733 | return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector, |
734 | analysis); |
735 | } |
736 | |
737 | //===----------------------------------------------------------------------===// |
738 | // hoistPaddingOnTensors Implementation. |
739 | //===----------------------------------------------------------------------===// |
740 | |
741 | /// Return true if we can walk back the use-def chain from `extractSliceOp` to |
742 | /// expectedSource going through DestinationStyleOpInterface inits only. |
/// This is a poor man's analysis that is sufficient to check that the
/// extractSliceOp matches the tensor.pad we want to hoist.
745 | /// In the future, it will be easier to ensure this with a matching symmetric |
746 | /// tensor.unpad op. |
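/// For example (an illustrative sketch): in
/// ```
///   %padded = tensor.pad %slice ...
///   %filled = linalg.fill ins(%cst) outs(%padded)
///   %extracted = tensor.extract_slice %filled ...
/// ```
/// walking back from %extracted through the init operand of the fill reaches
/// %padded, so the function returns true when `expectedSource` is %padded.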
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
749 | LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp |
750 | << "\n" ); |
751 | LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n" ); |
752 | Value source = extractSliceOp.getSource(); |
753 | LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n" ); |
754 | while (source && source != expectedSource) { |
755 | auto destOp = |
756 | dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp()); |
757 | if (!destOp) |
758 | break; |
759 | LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n" ); |
    source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
761 | ->get(); |
762 | } |
763 | LLVM_DEBUG(DBGS() << "--final source: " << source << "\n" ); |
764 | LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n" ); |
765 | return source == expectedSource; |
766 | } |
767 | |
768 | /// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an |
769 | /// iter arg), propagate the `hoistedPackedTensor` value through the same iter |
770 | /// arg. |
771 | /// TODO: for multiple loops we need to track the use to the innermost loop. |
772 | /// |
773 | /// Match: |
774 | /// ``` |
775 | /// %outerSliceOp = tensor.extract_slice .. |
776 | /// %f = scf.for ... iter_args(%arg0 = %outerSliceOp) { |
777 | /// %hoistedPackedTensor = tensor.pad %arg0 |
778 | /// %1 = compute %hoistedPackedTensor |
779 | /// %2 = tensor.extract_slice %1 |
780 | /// scf.yield %2 |
781 | /// } |
782 | /// ``` |
783 | /// |
784 | /// and rewrite as: |
785 | /// ``` |
786 | /// %outerSliceOp = tensor.extract_slice .. |
787 | /// %hoistedPackedTensor = tensor.pad %outerSliceOp |
788 | /// %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) { |
789 | /// %1 = compute %arg0 |
790 | /// scf.yield %1 |
791 | /// } |
792 | /// %2 = tensor.extract_slice %forOp |
793 | /// ``` |
794 | /// |
795 | /// Return null when no rewrite happened. |
796 | static tensor::ExtractSliceOp |
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
800 | LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n" ); |
801 | LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: " |
802 | << paddedValueBeforeHoisting << "\n" ); |
803 | OpOperand *pUse = nullptr; |
804 | for (OpOperand &use : outerSliceOp->getUses()) { |
805 | if (use.getOwner() == forOp) { |
806 | assert(!pUse && "Multiple slice uses in the for loop" ); |
807 | pUse = &use; |
808 | } |
809 | } |
810 | assert(pUse && "No slice use in the for loop" ); |
811 | OpBuilder::InsertionGuard g(rewriter); |
812 | rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp()); |
813 | |
814 | unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber(); |
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
817 | if (!yieldingExtractSliceOp) |
818 | return tensor::ExtractSliceOp(); |
819 | |
820 | // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad. |
821 | // In the future, it will be easier to ensure this with a matching symmetric |
822 | // tensor.unpad op. |
823 | if (!tracesBackToExpectedValue(yieldingExtractSliceOp, |
824 | paddedValueBeforeHoisting)) |
825 | return tensor::ExtractSliceOp(); |
826 | |
827 | SmallVector<Value> initArgs = forOp.getInitArgs(); |
828 | initArgs[iterArgNumber] = hoistedPackedTensor; |
829 | SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues()); |
830 | yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource(); |
831 | |
832 | int64_t numOriginalForOpResults = initArgs.size(); |
833 | LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults |
834 | << "\n" ); |
  tensor::ExtractSliceOp extracted;
836 | { |
837 | OpBuilder::InsertionGuard g(rewriter); |
838 | rewriter.setInsertionPointAfter(forOp); |
839 | extracted = rewriter.create<tensor::ExtractSliceOp>( |
840 | hoistedPackedTensor.getLoc(), hoistedPackedTensor, |
841 | outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(), |
842 | outerSliceOp.getMixedStrides()); |
843 | rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted); |
844 | } |
845 | scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields( |
846 | rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true, |
847 | [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) { |
848 | return yieldOperands; |
849 | })); |
850 | |
851 | LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults() |
852 | << "\n" ); |
853 | LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n" ); |
854 | LLVM_DEBUG(DBGS() << "with result #" |
855 | << numOriginalForOpResults + iterArgNumber |
856 | << " of forOp, giving us: " << extracted << "\n" ); |
  rewriter.startOpModification(extracted);
858 | extracted.getSourceMutable().assign( |
859 | newForOp.getResult(numOriginalForOpResults + iterArgNumber)); |
  rewriter.finalizeOpModification(extracted);
861 | |
862 | LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting |
863 | << "\n" ); |
864 | LLVM_DEBUG(DBGS() << "with region iter arg #" |
865 | << numOriginalForOpResults + iterArgNumber << "\n" ); |
866 | rewriter.replaceAllUsesWith( |
867 | paddedValueBeforeHoisting, |
868 | newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber)); |
869 | |
870 | return extracted; |
871 | } |
872 | |
873 | /// Produce a tensor extracted from the packingResult. This can be used as a |
874 | /// replacement for `opToHoist` in callers. |
875 | static Value replaceByPackingResult(RewriterBase &rewriter, |
876 | const IRMapping &bvm, |
877 | tensor::PadOp opToHoist, |
878 | RankedTensorType transposedTensorType, |
879 | const HoistPaddingAnalysis &analysis, |
880 | const PackingResult &packingResult) { |
881 | // The replacement occurs under a single insertion point within the original |
882 | // loop, just before opToHoist. |
883 | OpBuilder::InsertionGuard g(rewriter); |
884 | rewriter.setInsertionPoint(opToHoist); |
885 | |
886 | Location loc = opToHoist->getLoc(); |
887 | RankedTensorType paddedTensorType = opToHoist.getResultType(); |
888 | int paddedRank = paddedTensorType.getRank(); |
889 | |
890 | int64_t nPackedLoops = packingResult.clonedLoopIvs.size(); |
891 | LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n" ); |
892 | |
893 | scf::ForOp outerLoop = analysis.outermostEnclosingForOp; |
894 | ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops; |
895 | |
896 | Value hoistedPackedTensor; |
897 | SmallVector<Value> loopIterationCounts; |
898 | SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank, |
899 | rewriter.getIndexAttr(0)); |
900 | if (nPackedLoops > 0) { |
901 | loopIterationCounts = |
902 | llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) { |
903 | return buildLoopIterationCount(rewriter, outerLoop, |
904 | cast<scf::ForOp>(loop)); |
905 | })); |
906 | // Assert all loop iteration counts can be computed. |
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
908 | llvm_unreachable("loop independence prerequisite not met" ); |
909 | |
910 | // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0]. |
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
913 | hoistedPackedTensor = |
914 | scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front()) |
915 | ->getResult(0); |
916 | } else { |
917 | // If no loops were created, this is just hoisting without packing. |
918 | hoistedPackedTensor = bvm.lookup(opToHoist.getResult()); |
919 | } |
920 | |
921 | LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n" ); |
922 | |
923 | // If the consumer of `padOp` was a `forOp`, propagate through iter args. |
924 | scf::ForOp forOp = analysis.padConsumingForOp; |
925 | if (forOp) { |
926 | return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor, |
927 | analysis.sliceOp, forOp); |
928 | } |
929 | |
930 | // offsets = [maybe_leading_ivs, 0 .. 0]. |
931 | // sizes = [1 .. 1, transposedShape] (defined above). |
932 | // strides = [1 .. 1] (defined above) |
933 | return rewriter.create<tensor::ExtractSliceOp>( |
934 | loc, transposedTensorType, hoistedPackedTensor, offsets, |
935 | packingResult.sizes, packingResult.strides); |
936 | } |
937 | |
938 | FailureOr<Value> mlir::linalg::hoistPaddingOnTensors( |
939 | RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, |
940 | ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp, |
941 | SmallVectorImpl<GenericOp> &transposeOps) { |
942 | LLVM_DEBUG(DBGS() << "\n" ; DBGS() << " Try to hoist " << *(opToHoist) << "\n" ; |
943 | DBGS() << " by " << numLoops << " loops\n" ); |
944 | |
945 | HoistPaddingAnalysis analysis(opToHoist, numLoops); |
946 | analysis.enableHoistPadding(rewriter); |
947 | analysis.finalizeHoistPaddingAnalysis(); |
948 | if (!analysis.isValid()) { |
949 | LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n" ); |
950 | return failure(); |
951 | } |
952 | |
953 | /// Construct the packing loop nest. |
954 | IRMapping bvm; |
955 | FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl( |
956 | rewriter, bvm, opToHoist, transposeVector, analysis); |
  if (failed(packingResult)) {
958 | LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n" ); |
959 | return failure(); |
960 | } |
961 | |
962 | if (!transposeVector.empty()) |
963 | transposeOps.push_back(packingResult->maybeTransposeOp); |
964 | |
965 | FailureOr<RankedTensorType> transposedTensorType = |
      tensor::computeTransposedType(opToHoist.getResultType(), transposeVector);
967 | assert(succeeded(transposedTensorType) && "unexpected failure in type" ); |
968 | |
969 | // Now the packed tensor is ready, replace the original padding op by a |
970 | // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1]. |
971 | Value newResult = |
972 | replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType, |
973 | analysis, *packingResult); |
974 | |
975 | Location loc = opToHoist->getLoc(); |
976 | RankedTensorType paddedTensorType = opToHoist.getResultType(); |
977 | if (!transposeVector.empty()) { |
978 | OpBuilder::InsertionGuard g(rewriter); |
979 | rewriter.setInsertionPointAfter(newResult.getDefiningOp()); |
980 | // Transpose the packed tensor back to the original storage order. |
981 | Value emptyTensor = rewriter.create<tensor::EmptyOp>( |
982 | loc, paddedTensorType.getShape(), paddedTensorType.getElementType()); |
983 | GenericOp unTransposeOp = |
984 | makeTransposeOp(rewriter, loc, newResult, emptyTensor, transposeVector); |
985 | newResult = unTransposeOp.getResult(0); |
986 | transposeOps.push_back(unTransposeOp); |
987 | } |
988 | |
989 | LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n" ); |
990 | LLVM_DEBUG( |
991 | DBGS() << "After hoisting: " |
992 | << newResult.getDefiningOp()->getParentOfType<func::FuncOp>() |
993 | << "\n" ); |
994 | |
995 | // Make the newly cloned `opToHoist` available to the caller. |
996 | hoistedOp = packingResult->hoistedPadOp; |
997 | |
998 | LLVM_DEBUG(DBGS() << "--SUCCESS\n" ); |
999 | return newResult; |
1000 | } |
1001 | |
1002 | FailureOr<Value> |
1003 | mlir::linalg::hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops, |
1004 | ArrayRef<int64_t> transposeVector, |
1005 | tensor::PadOp &hoistedOp, |
1006 | SmallVectorImpl<GenericOp> &transposeOps) { |
1007 | IRRewriter rewriter(opToHoist.getContext()); |
1008 | return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector, |
1009 | hoistedOp, transposeOps); |
1010 | } |
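// A minimal usage sketch from a hypothetical caller (e.g. a rewrite pattern or
// transform op; `rewriter` and `padOp` are assumed to exist in the caller):
//
//   tensor::PadOp hoistedPad;
//   SmallVector<GenericOp> transposeOps;
//   FailureOr<Value> replacement = linalg::hoistPaddingOnTensors(
//       rewriter, padOp, /*numLoops=*/2, /*transposeVector=*/{}, hoistedPad,
//       transposeOps);
//   if (succeeded(replacement))
//     rewriter.replaceOp(padOp, *replacement);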
1011 | |