1 | //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the tiling using TilingInterface. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
14 | |
15 | #include "mlir/Analysis/SliceAnalysis.h" |
16 | #include "mlir/Analysis/TopologicalSortUtils.h" |
17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
21 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
22 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
24 | #include "mlir/IR/Dominance.h" |
25 | #include "mlir/IR/Matchers.h" |
26 | #include "mlir/IR/PatternMatch.h" |
27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
28 | #include "mlir/Interfaces/TilingInterface.h" |
29 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
31 | #include "llvm/ADT/ScopeExit.h" |
32 | #include "llvm/ADT/TypeSwitch.h" |
33 | #include "llvm/Support/Debug.h" |
34 | #include <optional> |
35 | |
36 | #define DEBUG_TYPE "tile-using-interface" |
37 | |
38 | using namespace mlir; |
39 | |
40 | scf::SCFTilingOptions & |
41 | scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { |
42 | assert(!tileSizeComputationFunction && "tile sizes already set"); |
43 | auto tileSizes = llvm::to_vector(Range&: ts); |
44 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { |
45 | return tileSizes; |
46 | }; |
47 | return *this; |
48 | } |
49 | |
50 | scf::SCFTilingOptions & |
51 | scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { |
52 | assert(!numThreadsComputationFunction && "num tiles already set"); |
53 | auto numThreads = llvm::to_vector(Range&: nt); |
54 | numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { |
55 | return numThreads; |
56 | }; |
57 | return *this; |
58 | } |
59 | |
60 | /// Helper method to adjust the interchange vector to match the iteration |
61 | /// domain. |
62 | static SmallVector<int64_t> |
63 | fillInterchangeVector(ArrayRef<int64_t> interchangeVector, |
64 | size_t iterationDomainSize) { |
65 | SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector); |
66 | if (filledVector.size() < iterationDomainSize) { |
67 | auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize); |
68 | filledVector.append(in_start: range.begin(), in_end: range.end()); |
69 | } |
70 | if (filledVector.size() > iterationDomainSize) |
71 | filledVector.resize(N: iterationDomainSize); |
72 | return filledVector; |
73 | } |
74 | |
75 | //===----------------------------------------------------------------------===// |
76 | // tileUsingSCF implementation. |
77 | //===----------------------------------------------------------------------===// |
78 | |
79 | /// Verify the tile size options are set in a consistent manner. |
80 | static LogicalResult |
81 | verifyTileSizeOptions(RewriterBase &rewriter, Location loc, |
82 | const scf::SCFTilingOptions &options) { |
83 | // Specifying number of threads is only supported on `scf.forall` op. |
84 | if (options.numThreadsComputationFunction && |
85 | options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { |
86 | return rewriter.notifyMatchFailure( |
87 | arg&: loc, msg: "number of threads can only by specified when loop type is " |
88 | "set to use `scf.forall`"); |
89 | } |
90 | |
91 | // If specified, check that the interchange vector is a permutation. |
92 | if (!options.interchangeVector.empty()) { |
93 | if (!isPermutationVector(interchange: options.interchangeVector)) { |
94 | return rewriter.notifyMatchFailure( |
95 | arg&: loc, msg: "invalid interchange vector, not a permutation of the entire " |
96 | "iteration space"); |
97 | } |
98 | } |
99 | return success(); |
100 | } |
101 | |
102 | /// Method to instantiate the tile sizes and/or number of threads specified |
103 | /// by the user. |
104 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> |
105 | getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, |
106 | ArrayRef<Range> iterationDomain, |
107 | const scf::SCFTilingOptions &options) { |
108 | OpFoldResult zero = rewriter.getIndexAttr(0); |
109 | SmallVector<OpFoldResult> tileSizes, numThreads; |
110 | size_t numLoops = iterationDomain.size(); |
111 | |
112 | // Check whether the number of tiles to use is specified. |
113 | if (options.numThreadsComputationFunction) { |
114 | numThreads = options.numThreadsComputationFunction(rewriter, op); |
115 | numThreads.resize(N: numLoops, NV: zero); |
116 | |
117 | // If the number of tiles is also specified, use that. |
118 | if (options.tileSizeComputationFunction) { |
119 | tileSizes = options.tileSizeComputationFunction(rewriter, op); |
120 | tileSizes.resize(N: numLoops, NV: zero); |
121 | return {tileSizes, numThreads}; |
122 | } |
123 | |
124 | // Compute the tile sizes from the iteration domain and number |
125 | // of tiles as follows |
126 | // - niters = ceilDiv(ub - lb, step) |
127 | // - tileSize = ceilDiv(niters, numThreads) |
128 | AffineExpr s0, s1, s2; |
129 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2); |
130 | // TODO: The step here is assumed to be 1. |
131 | AffineExpr numItersExpr = (s1 - s0); |
132 | AffineExpr tileSizeExpr = numItersExpr.ceilDiv(other: s2); |
133 | tileSizes.resize(N: numLoops, NV: zero); |
134 | for (auto [index, range, nt] : |
135 | llvm::enumerate(First&: iterationDomain, Rest&: numThreads)) { |
136 | if (isZeroInteger(v: nt)) |
137 | continue; |
138 | |
139 | tileSizes[index] = affine::makeComposedFoldedAffineApply( |
140 | rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); |
141 | } |
142 | tileSizes.resize(N: numLoops, NV: zero); |
143 | return {tileSizes, numThreads}; |
144 | } |
145 | |
146 | // Enforce the convention that "tiling by zero" |
147 | // skips tiling a particular dimension. This convention is significantly |
148 | // simpler to handle instead of adjusting affine maps to account for missing |
149 | // dimensions. |
150 | assert(options.tileSizeComputationFunction && |
151 | "expected tile sizes to be specified"); |
152 | tileSizes = options.tileSizeComputationFunction(rewriter, op); |
153 | tileSizes.resize(N: numLoops, NV: zero); |
154 | |
155 | return {tileSizes, numThreads}; |
156 | } |
157 | |
158 | /// Checks if any of the tiled loops are not parallel. |
159 | static void checkSafeToTileToForall(TilingInterface op, |
160 | ArrayRef<OpFoldResult> tileSizes, |
161 | ArrayRef<OpFoldResult> numThreads) { |
162 | auto iterators = op.getLoopIteratorTypes(); |
163 | assert(iterators.size() == tileSizes.size() && |
164 | "expected as many tile size values as number of loops"); |
165 | assert((numThreads.empty() || (numThreads.size() == iterators.size())) && |
166 | "when specified, expected number of threads to use for each loop"); |
167 | |
168 | for (auto [index, iterator, tileSize] : |
169 | llvm::enumerate(iterators, tileSizes)) { |
170 | // If num threads is specified, check that it is greater than one only for |
171 | // parallel dimensions. |
172 | if (!numThreads.empty()) { |
173 | if (std::optional<int64_t> constNumThreads = |
174 | getConstantIntValue(numThreads[index])) { |
175 | if (constNumThreads.value() > 1 && |
176 | iterator != utils::IteratorType::parallel) { |
177 | op.emitWarning() << "tiling is not thread safe at axis #"<< index; |
178 | } |
179 | } |
180 | continue; |
181 | } |
182 | |
183 | if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { |
184 | if (constTileSize.value() > 0 && |
185 | iterator != utils::IteratorType::parallel) { |
186 | op.emitWarning() << "tiling is not thread safe at axis #"<< index; |
187 | } |
188 | } |
189 | } |
190 | } |
191 | |
192 | /// Check if `stride` evenly divides the trip count `size - offset`. |
193 | static bool tileDividesIterationDomain(Range loopRange) { |
194 | std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset); |
195 | if (!offsetAsInt) |
196 | return false; |
197 | std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size); |
198 | if (!sizeAsInt) |
199 | return false; |
200 | std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride); |
201 | if (!strideAsInt) |
202 | return false; |
203 | return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); |
204 | } |
205 | |
206 | /// Returns the bounded tile size given the current `offset`, `loopRange` and |
207 | /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. |
208 | static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, |
209 | Range loopRange, OpFoldResult offset, |
210 | OpFoldResult tileSize) { |
211 | std::optional<int64_t> ts = getConstantIntValue(ofr: tileSize); |
212 | if (ts && ts.value() == 1) |
213 | return tileSize; |
214 | |
215 | if (tileDividesIterationDomain( |
216 | loopRange: Range{.offset: loopRange.offset, .size: loopRange.size, .stride: tileSize})) |
217 | return tileSize; |
218 | |
219 | // The tile size to use (to avoid out of bounds access) is minimum of |
220 | // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled |
221 | // loop. |
222 | AffineExpr s0, s1, d0; |
223 | bindDims(ctx: b.getContext(), exprs&: d0); |
224 | bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1); |
225 | AffineMap minMap = AffineMap::get(dimCount: 1, symbolCount: 2, results: {s0 - d0, s1}, context: b.getContext()); |
226 | Value size = getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange.size); |
227 | return affine::makeComposedFoldedAffineMin( |
228 | b, loc, map: minMap, operands: SmallVector<OpFoldResult>{offset, size, tileSize}); |
229 | } |
230 | |
231 | /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less |
232 | /// than `iterationSize`. |
233 | static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, |
234 | OpFoldResult numThreads, |
235 | OpFoldResult iterationSize) { |
236 | std::optional<int64_t> tileSizeConst = getConstantIntValue(ofr: tileSize); |
237 | std::optional<int64_t> numThreadsConst = getConstantIntValue(ofr: numThreads); |
238 | std::optional<int64_t> iterSizeConst = getConstantIntValue(ofr: iterationSize); |
239 | if (!tileSizeConst || !numThreadsConst || !iterSizeConst) |
240 | return false; |
241 | return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; |
242 | } |
243 | |
244 | /// Compute the `OpFoldResult`s that represents the multi-dimensional |
245 | /// `offset`s and `size`s of the tile of the iteration space that the |
246 | /// innermost loop body of the generated tiled loops corresponds to. |
247 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> |
248 | getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, |
249 | ArrayRef<Range> iterationDomain, |
250 | ArrayRef<OpFoldResult> tileSizes, |
251 | ArrayRef<OpFoldResult> numThreads) { |
252 | SmallVector<OpFoldResult> offsets, sizes; |
253 | int materializedLoopNum = 0; |
254 | |
255 | if (!numThreads.empty()) { |
256 | AffineExpr d0, d1, s0, s1; |
257 | AffineExpr offsetExpr, residualTileSizeExpr; |
258 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1); |
259 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1); |
260 | offsetExpr = d0 + d1 * s0; |
261 | residualTileSizeExpr = s1 - (d0 + d1 * s0); |
262 | |
263 | for (auto [nt, tileSize, loopRange] : |
264 | llvm::zip_equal(t&: numThreads, u&: tileSizes, args&: iterationDomain)) { |
265 | |
266 | // Non-tiled cases, set the offset and size to the |
267 | // `loopRange.offset/size`. |
268 | if (isZeroInteger(v: nt)) { |
269 | offsets.push_back(Elt: loopRange.offset); |
270 | sizes.push_back(Elt: loopRange.size); |
271 | continue; |
272 | } |
273 | |
274 | Value iv = ivs[materializedLoopNum++]; |
275 | OpFoldResult offset = affine::makeComposedFoldedAffineApply( |
276 | b&: rewriter, loc, expr: offsetExpr, |
277 | operands: ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); |
278 | OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( |
279 | b&: rewriter, loc, expr: residualTileSizeExpr, |
280 | operands: {loopRange.offset, nt, tileSize, loopRange.size}); |
281 | |
282 | OpFoldResult size = tileSize; |
283 | if (!isZeroInteger(v: residualTileSize)) { |
284 | OpFoldResult sizeMinusOffsetPerThread = |
285 | affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: s0 - d0, |
286 | operands: {offset, loopRange.size}); |
287 | size = affine::makeComposedFoldedAffineMin( |
288 | b&: rewriter, loc, |
289 | map: AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()), |
290 | operands: {sizeMinusOffsetPerThread, tileSize}); |
291 | } |
292 | |
293 | // Consider the case where the original loop was `[0, 100)`. |
294 | // If number of threads are `7`, the tile size would be computed as |
295 | // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) |
296 | // - `offset = 0 + 6 * 15 = 105` |
297 | // - `tileSize = min(15, 100 - 105) = -5` |
298 | // To avoid negative tile sizes, we need to do a further |
299 | // `nonNegativeTileSize = affine.max(0, tileSize)`. |
300 | // This `max` can be avoided if |
301 | // `offset + tileSize * (numThreads - 1) < (ub - lb)` |
302 | if (!canOmitTileOffsetInBoundsCheck(tileSize, numThreads: nt, iterationSize: loopRange.size)) { |
303 | AffineMap maxMap = |
304 | AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()); |
305 | size = affine::makeComposedFoldedAffineMax( |
306 | rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); |
307 | } |
308 | |
309 | offsets.push_back(Elt: offset); |
310 | sizes.push_back(Elt: size); |
311 | } |
312 | return {offsets, sizes}; |
313 | } else { |
314 | for (auto [tileSize, loopRange] : |
315 | llvm::zip_equal(t&: tileSizes, u&: iterationDomain)) { |
316 | |
317 | // Non-tiled cases, set the offset and size to the |
318 | // `loopRange.offset/size`. |
319 | if (isZeroInteger(v: tileSize)) { |
320 | offsets.push_back(Elt: loopRange.offset); |
321 | sizes.push_back(Elt: loopRange.size); |
322 | continue; |
323 | } |
324 | |
325 | Value iv = ivs[materializedLoopNum++]; |
326 | OpFoldResult offset = getAsOpFoldResult(val: iv); |
327 | offsets.push_back(Elt: offset); |
328 | OpFoldResult size = |
329 | getBoundedTileSize(b&: rewriter, loc, loopRange, offset, tileSize); |
330 | sizes.push_back(Elt: size); |
331 | } |
332 | return {offsets, sizes}; |
333 | } |
334 | } |
335 | |
336 | /// Function to return the bounds of the loops to be generated. |
337 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
338 | SmallVector<OpFoldResult>> |
339 | getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
340 | ArrayRef<OpFoldResult> tileSizes) { |
341 | SmallVector<OpFoldResult> lbs, ubs, steps; |
342 | for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) { |
343 | // No loop if the tile size is 0. |
344 | if (isZeroInteger(v: tileSize)) |
345 | continue; |
346 | lbs.push_back(Elt: loopRange.offset); |
347 | ubs.push_back(Elt: loopRange.size); |
348 | steps.push_back(Elt: tileSize); |
349 | } |
350 | return {lbs, ubs, steps}; |
351 | } |
352 | |
353 | /// A function that allows returning additional yielded values during |
354 | /// `yieldTiledValuesAndReplace`. |
355 | /// - `ivs` induction variable for the loop. |
356 | /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. |
357 | /// - `tiledValues` the tiled values to return. Must be of same size as |
358 | /// `newbbArgs`, each element of this array is inserted into the corresponding |
359 | /// element in `newbbArgs`. |
360 | /// - `resultOffsets` is of the same size as `tiledValues` and represents |
361 | /// the offsets to use when inserting corresponding element from `tiledValues` |
362 | /// into the element from `newBbArgs`. |
363 | /// - `resultSizes` is of the same size as `tiledValues` and represents |
364 | /// the size of the corresponding element from `tiledValues` inserted into |
365 | /// the element from `newBbArgs`. |
366 | /// In case the method needs to return `failure()` the method is expected |
367 | /// to clean up any inserted operations. |
368 | using YieldTiledValuesFn = std::function<LogicalResult( |
369 | RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, |
370 | SmallVector<Value> &tiledValues, |
371 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
372 | SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; |
373 | |
374 | /// Clones the operation and updates the destination if the operation |
375 | /// implements the `DestinationStyleOpInterface`. |
376 | static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, |
377 | Operation *op, |
378 | ValueRange newDestArgs) { |
379 | Operation *clonedOp = rewriter.clone(op&: *op); |
380 | if (newDestArgs.empty()) |
381 | return clonedOp; |
382 | if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) |
383 | destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); |
384 | return clonedOp; |
385 | } |
386 | |
387 | /// Generate the tile-loop nest using `scf.for` operation. |
388 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
389 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
390 | /// - `destinationTensors` are the init values to use for the outer most loop. |
391 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
392 | /// most |
393 | /// loop. |
394 | /// - `loops` is an in-out parameter into which the generated loops are |
395 | /// populated. |
396 | static LogicalResult generateLoopNestUsingForOp( |
397 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
398 | ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, |
399 | YieldTiledValuesFn yieldTiledValuesFn, |
400 | SmallVector<LoopLikeOpInterface> &loops) { |
401 | assert(!loopRanges.empty() && "unexpected empty loop ranges"); |
402 | assert(loopRanges.size() == tileSizes.size() && |
403 | "expected as many tile sizes as loop ranges"); |
404 | OpBuilder::InsertionGuard guard(rewriter); |
405 | |
406 | SmallVector<OpFoldResult> lbs, ubs, steps; |
407 | std::tie(args&: lbs, args&: ubs, args&: steps) = |
408 | getLoopBounds(rewriter, loc, loopRanges, tileSizes); |
409 | SmallVector<Value> lbVals = |
410 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: lbs); |
411 | SmallVector<Value> ubVals = |
412 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: ubs); |
413 | SmallVector<Value> stepVals = |
414 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: steps); |
415 | |
416 | SmallVector<Value> ivs; |
417 | for (auto [lb, ub, step] : llvm::zip_equal(t&: lbVals, u&: ubVals, args&: stepVals)) { |
418 | auto loop = |
419 | rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, |
420 | [](OpBuilder &bodyBuilder, Location bodyLoc, |
421 | Value iv, ValueRange /*iterArgs*/) {}); |
422 | loops.push_back(loop); |
423 | ivs.push_back(Elt: loop.getInductionVar()); |
424 | rewriter.setInsertionPointToEnd(loop.getBody()); |
425 | destinationTensors = loop.getRegionIterArgs(); |
426 | } |
427 | |
428 | SmallVector<Value> tiledResults; |
429 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
430 | if (failed(Result: yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, |
431 | tiledResults, resultOffsets, resultSizes))) { |
432 | return rewriter.notifyMatchFailure( |
433 | arg&: loc, msg: "failed to generate inner tile loop body"); |
434 | } |
435 | if (loops.empty()) |
436 | return success(); |
437 | |
438 | assert(tiledResults.size() == destinationTensors.size() && |
439 | "Number of results of body should be equal to number of iter args"); |
440 | |
441 | // 6. Yield all the results of the tiled operation. |
442 | SmallVector<Value> yieldedValues; |
443 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
444 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
445 | args&: resultSizes)) { |
446 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
447 | rewriter.getIndexAttr(1)); |
448 | auto insertSlice = rewriter.create<tensor::InsertSliceOp>( |
449 | loc, tiledValue, destinationTensor, resultOffset, resultSize, |
450 | resultStride); |
451 | yieldedValues.push_back(Elt: insertSlice); |
452 | } |
453 | rewriter.create<scf::YieldOp>(loc, yieldedValues); |
454 | |
455 | // Add the scf.yield operations for all the outer loops. |
456 | for (auto [outerLoop, innerLoop] : |
457 | llvm::zip_equal(MutableArrayRef(loops).drop_back(), |
458 | MutableArrayRef(loops).drop_front())) { |
459 | rewriter.setInsertionPointToEnd( |
460 | cast<scf::ForOp>(outerLoop.getOperation()).getBody()); |
461 | rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); |
462 | } |
463 | return success(); |
464 | } |
465 | |
466 | /// Generate the tile-loop nest using `scf.forall` operation. |
467 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
468 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
469 | /// - `destinationTensors` are the init values to use for the outer most loop. |
470 | /// - `mappingVector` is the mapping attributes to use for loop construction. |
471 | /// Can be empty. |
472 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
473 | /// most |
474 | /// loop. |
475 | /// - `loops` is an in-out parameter into which the generated loops are |
476 | /// populated. |
477 | static LogicalResult generateLoopNestUsingForallOp( |
478 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
479 | ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, |
480 | ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, |
481 | YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { |
482 | assert(!loopRanges.empty() && "unexpected empty loop ranges"); |
483 | assert(loopRanges.size() == tileSizes.size() && |
484 | "expected as many tile sizes as loop ranges"); |
485 | OpBuilder::InsertionGuard guard(rewriter); |
486 | |
487 | std::optional<ArrayAttr> mappingAttr; |
488 | if (!mappingVector.empty()) |
489 | mappingAttr = rewriter.getArrayAttr(mappingVector); |
490 | |
491 | scf::ForallOp forallOp; |
492 | bool useNumThreads = !numThreads.empty(); |
493 | |
494 | if (useNumThreads) { |
495 | // Prune the zero numthreads. |
496 | SmallVector<OpFoldResult> nonZeroNumThreads; |
497 | for (auto nt : numThreads) { |
498 | if (isZeroInteger(v: nt)) |
499 | continue; |
500 | nonZeroNumThreads.push_back(Elt: nt); |
501 | } |
502 | forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, |
503 | destinationTensors, mappingAttr); |
504 | } else { |
505 | SmallVector<OpFoldResult> lbs, ubs, steps; |
506 | std::tie(args&: lbs, args&: ubs, args&: steps) = |
507 | getLoopBounds(rewriter, loc, loopRanges, tileSizes); |
508 | forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, |
509 | destinationTensors, mappingAttr); |
510 | } |
511 | loops.push_back(forallOp); |
512 | |
513 | rewriter.setInsertionPoint(forallOp.getTerminator()); |
514 | destinationTensors = forallOp.getRegionOutArgs(); |
515 | |
516 | SmallVector<Value> tiledResults; |
517 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
518 | if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), |
519 | destinationTensors, tiledResults, resultOffsets, |
520 | resultSizes))) |
521 | return rewriter.notifyMatchFailure(arg&: loc, msg: "failed to generate loop body"); |
522 | |
523 | rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); |
524 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
525 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
526 | args&: resultSizes)) { |
527 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
528 | rewriter.getIndexAttr(1)); |
529 | |
530 | rewriter.create<tensor::ParallelInsertSliceOp>( |
531 | loc, tiledValue, destinationTensor, resultOffset, resultSize, |
532 | resultStride); |
533 | } |
534 | return success(); |
535 | } |
536 | |
537 | /// Generate the tile-loop nest using the loop construct specifed in `options`. |
538 | /// - `options`: Tiling options specified. |
539 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
540 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
541 | /// - `destinationTensors` are the init values to use for the outer most loop. |
542 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
543 | /// most |
544 | /// loop. |
545 | /// - `loops` is an in-out parameter into which the generated loops are |
546 | /// populated. |
547 | static LogicalResult generateLoopNest( |
548 | RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, |
549 | ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, |
550 | ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, |
551 | YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { |
552 | // If the tile sizes are all zero, no loops are generated. Just call the |
553 | // callback function to handle untiled case. |
554 | if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) { |
555 | SmallVector<Value> tiledResults; |
556 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
557 | return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, |
558 | tiledResults, resultOffsets, resultSizes); |
559 | } |
560 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { |
561 | return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, |
562 | destinationTensors, tiledBodyFn, loops); |
563 | } |
564 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { |
565 | return generateLoopNestUsingForallOp( |
566 | rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, |
567 | destinationTensors, tiledBodyFn, loops); |
568 | } |
569 | return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type"); |
570 | } |
571 | |
572 | static FailureOr<SmallVector<Value>> |
573 | createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, |
574 | ArrayRef<OpFoldResult> tileSizes, |
575 | const scf::SCFTilingOptions &options) { |
576 | SmallVector<Value> initTensors; |
577 | Location loc = op->getLoc(); |
578 | switch (options.reductionStrategy) { |
579 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
580 | if (failed(tensor::getOrCreateDestinations(b&: rewriter, loc, op: op, result&: initTensors))) |
581 | return failure(); |
582 | return initTensors; |
583 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
584 | PartialReductionOuterReduction: { |
585 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
586 | if (!redOp) { |
587 | return rewriter.notifyMatchFailure( |
588 | op, "PartialReductionOuterReduction tiling strategy is only supported" |
589 | "for operations implementing PartialReductionOpInterface"); |
590 | } |
591 | // Get reduction dimensions. |
592 | // TODO: PartialReductionOpInterface should really query TilingInterface |
593 | // itself and find reduction dimensions. |
594 | SmallVector<int> reductionDims; |
595 | for (auto [idx, iteratorType] : |
596 | llvm::enumerate(op.getLoopIteratorTypes())) { |
597 | if (iteratorType == utils::IteratorType::reduction) |
598 | reductionDims.push_back(idx); |
599 | } |
600 | return redOp.generateInitialTensorForPartialReduction( |
601 | rewriter, loc, tileSizes, reductionDims); |
602 | } |
603 | default: |
604 | return rewriter.notifyMatchFailure(op, |
605 | "unhandled reduction tiling strategy"); |
606 | } |
607 | } |
608 | |
609 | static FailureOr<TilingResult> |
610 | getTiledImplementation(RewriterBase &rewriter, TilingInterface op, |
611 | ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets, |
612 | ArrayRef<OpFoldResult> sizes, |
613 | const scf::SCFTilingOptions &options) { |
614 | switch (options.reductionStrategy) { |
615 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
616 | return op.getTiledImplementation(rewriter, offsets, sizes); |
617 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
618 | PartialReductionOuterReduction: { |
619 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
620 | if (!redOp) { |
621 | return rewriter.notifyMatchFailure( |
622 | op, "PartialReductionOuterReduction tiling strategy is only " |
623 | "supported for operations " |
624 | "implementing PartialReductionOpInterface"); |
625 | } |
626 | // Get reduction dimensions. |
627 | // TODO: PartialReductionOpInterface should really query TilingInterface |
628 | // itself and find reduction dimensions. |
629 | SmallVector<int> reductionDims; |
630 | for (auto [idx, iteratorType] : |
631 | llvm::enumerate(op.getLoopIteratorTypes())) { |
632 | if (iteratorType == utils::IteratorType::reduction) |
633 | reductionDims.push_back(idx); |
634 | } |
635 | return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg, |
636 | offsets, sizes, reductionDims); |
637 | } |
638 | default: |
639 | return rewriter.notifyMatchFailure(op, |
640 | "unhandled reduction tiling strategy"); |
641 | } |
642 | } |
643 | |
644 | static LogicalResult |
645 | getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, |
646 | TilingInterface op, ArrayRef<OpFoldResult> offsets, |
647 | ArrayRef<OpFoldResult> sizes, |
648 | SmallVector<OpFoldResult> &resultOffset, |
649 | SmallVector<OpFoldResult> &resultSize, |
650 | const scf::SCFTilingOptions &options) { |
651 | |
652 | switch (options.reductionStrategy) { |
653 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
654 | return op.getResultTilePosition(rewriter, index, offsets, sizes, |
655 | resultOffset, resultSize); |
656 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
657 | PartialReductionOuterReduction: { |
658 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
659 | if (!redOp) { |
660 | return rewriter.notifyMatchFailure( |
661 | op, "PartialReductionOuterReduction tiling strategy is only supported" |
662 | "for operations implementing PartialReductionOpInterface"); |
663 | } |
664 | // Get reduction dimensions. |
665 | // TODO: PartialReductionOpInterface should really query TilingInterface |
666 | // itself and find reduction dimensions. |
667 | SmallVector<int> reductionDims; |
668 | for (auto [idx, iteratorType] : |
669 | llvm::enumerate(op.getLoopIteratorTypes())) { |
670 | if (iteratorType == utils::IteratorType::reduction) |
671 | reductionDims.push_back(idx); |
672 | } |
673 | return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes, |
674 | resultOffset, resultSize, |
675 | reductionDims); |
676 | } |
677 | default: |
678 | return rewriter.notifyMatchFailure(op, |
679 | "unhandled reduction tiling strategy"); |
680 | } |
681 | } |
682 | |
683 | static FailureOr<MergeResult> |
684 | mergeTilingResults(RewriterBase &rewriter, TilingInterface op, |
685 | ValueRange partialResults, |
686 | const scf::SCFTilingOptions &options) { |
687 | switch (options.reductionStrategy) { |
688 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
689 | // No need to merge results for reduction tiling strategy. |
690 | return MergeResult{.mergeOps: {}, .replacements: partialResults}; |
691 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
692 | PartialReductionOuterReduction: { |
693 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
694 | if (!redOp) { |
695 | return rewriter.notifyMatchFailure( |
696 | op, "PartialReductionOuterReduction tiling strategy is only " |
697 | "supported for operations " |
698 | "implementing PartialReductionOpInterface"); |
699 | } |
700 | // Get reduction dimensions. |
701 | // TODO: PartialReductionOpInterface should really query TilingInterface |
702 | // itself and find reduction dimensions. |
703 | SmallVector<int> reductionDims; |
704 | for (auto [idx, iteratorType] : |
705 | llvm::enumerate(op.getLoopIteratorTypes())) { |
706 | if (iteratorType == utils::IteratorType::reduction) |
707 | reductionDims.push_back(idx); |
708 | } |
709 | return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, |
710 | reductionDims); |
711 | } |
712 | default: |
713 | return rewriter.notifyMatchFailure(op, |
714 | "unhandled reduction tiling strategy"); |
715 | } |
716 | } |
717 | |
718 | /// Append the specified additional `newInitOperands` operands to the |
719 | /// loops existing `init` operands (or similar), and replace `loopOp` with |
720 | /// the new loop that has the additional init operands. The loop body of |
721 | /// this loop is moved over to the new loop. `yieldTiledValuesFn` |
722 | /// is called to get the new tiled values returned, and the offset |
723 | /// and sizes at which the tiled value is inserted into the |
724 | /// new region iter_args that correspond to the newly added init operands. |
725 | template <typename LoopType> |
726 | FailureOr<LoopLikeOpInterface> |
727 | yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, |
728 | ValueRange newInitOperands, |
729 | YieldTiledValuesFn yieldTiledValuesFn) { |
730 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); |
731 | } |
732 | |
733 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. |
734 | template <> |
735 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( |
736 | scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
737 | YieldTiledValuesFn yieldTiledValuesFn) { |
738 | OpBuilder::InsertionGuard g(rewriter); |
739 | Location loc = loopOp.getLoc(); |
740 | rewriter.setInsertionPoint(loopOp); |
741 | |
742 | auto inits = llvm::to_vector(loopOp.getInitArgs()); |
743 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
744 | auto newLoop = rewriter.create<scf::ForOp>( |
745 | loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), |
746 | inits, [](OpBuilder &, Location, Value, ValueRange) {}); |
747 | |
748 | // Move the loop body to the new op. |
749 | Block *loopBody = loopOp.getBody(); |
750 | Block *newLoopBody = newLoop.getBody(); |
751 | rewriter.mergeBlocks( |
752 | loopBody, newLoopBody, |
753 | newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
754 | |
755 | auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); |
756 | rewriter.setInsertionPoint(yieldOp); |
757 | |
758 | SmallVector<Value> tiledValues; |
759 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
760 | ValueRange newRegionIterArgs = |
761 | newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
762 | if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), |
763 | newRegionIterArgs, tiledValues, resultOffsets, |
764 | resultSizes))) { |
765 | rewriter.eraseOp(newLoop); |
766 | return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); |
767 | } |
768 | |
769 | SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); |
770 | for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : |
771 | llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, |
772 | resultSizes)) { |
773 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
774 | rewriter.getIndexAttr(1)); |
775 | Value insert = rewriter.create<tensor::InsertSliceOp>( |
776 | yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, |
777 | resultStride); |
778 | newYieldValues.push_back(insert); |
779 | } |
780 | |
781 | rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); |
782 | rewriter.replaceOp(loopOp, |
783 | newLoop->getResults().take_front(loopOp.getNumResults())); |
784 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
785 | } |
786 | |
787 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` |
788 | template <> |
789 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( |
790 | scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
791 | YieldTiledValuesFn yieldTiledValuesFn) { |
792 | OpBuilder::InsertionGuard g(rewriter); |
793 | Location loc = loopOp.getLoc(); |
794 | rewriter.setInsertionPoint(loopOp); |
795 | auto inits = llvm::to_vector(loopOp.getOutputs()); |
796 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
797 | auto newLoop = rewriter.create<scf::ForallOp>( |
798 | loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), |
799 | loopOp.getMixedStep(), inits, loopOp.getMapping(), |
800 | [](OpBuilder &, Location, ValueRange) {}); |
801 | |
802 | // Move the region of the current block to the newly created op. |
803 | Block *loopBody = loopOp.getBody(); |
804 | Block *newLoopBody = newLoop.getBody(); |
805 | rewriter.mergeBlocks( |
806 | loopBody, newLoopBody, |
807 | newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
808 | |
809 | auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); |
810 | rewriter.setInsertionPoint(terminator); |
811 | SmallVector<Value> tiledValues; |
812 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
813 | ValueRange regionIterArgs = |
814 | newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
815 | if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), |
816 | regionIterArgs, tiledValues, resultOffsets, |
817 | resultSizes))) { |
818 | rewriter.eraseOp(newLoop); |
819 | return rewriter.notifyMatchFailure(loopOp, |
820 | "failed to get yielded tiled values"); |
821 | } |
822 | |
823 | // Update the terminator. |
824 | rewriter.setInsertionPointToEnd(terminator.getBody()); |
825 | |
826 | for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( |
827 | tiledValues, regionIterArgs, resultOffsets, resultSizes)) { |
828 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
829 | rewriter.getIndexAttr(1)); |
830 | rewriter.create<tensor::ParallelInsertSliceOp>( |
831 | terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, |
832 | resultStride); |
833 | } |
834 | |
835 | rewriter.replaceOp(loopOp, |
836 | newLoop->getResults().take_front(loopOp.getNumResults())); |
837 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
838 | } |
839 | |
840 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for |
841 | /// `LoopLikeOpInterface`, that just dispatches to the implementation for each |
842 | /// supported loop type. |
843 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( |
844 | LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, |
845 | ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { |
846 | return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( |
847 | loopLikeOp.getOperation()) |
848 | .Case<scf::ForOp, scf::ForallOp>( |
849 | [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
850 | return yieldTiledValuesAndReplaceLoop( |
851 | loopOp, rewriter, newInitOperands, yieldTiledValuesFn); |
852 | }) |
853 | .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
854 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); |
855 | }); |
856 | } |
857 | |
858 | /// Method to add new init values to a loop nest. Updates `loops` in-place |
859 | /// with new loops that use the `newInitValues`. The outer-loops are updated |
860 | /// to yield the new result values of the inner loop. For the innermost loop, |
861 | /// the call back `getNewYields` is invoked to get the additional values to |
862 | /// yield form the innermost loop. |
863 | static LogicalResult addInitOperandsToLoopNest( |
864 | RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, |
865 | ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { |
866 | if (loops.empty()) |
867 | return success(); |
868 | OpBuilder::InsertionGuard g(rewriter); |
869 | rewriter.setInsertionPoint(loops.front()); |
870 | |
871 | SmallVector<Value> ivs; |
872 | for (auto &loop : loops.drop_back()) { |
873 | rewriter.setInsertionPoint(loop); |
874 | |
875 | // if loops.size() > 1 we assume that scf.for is used for the loops. |
876 | auto forLoop = cast<scf::ForOp>(loop.getOperation()); |
877 | |
878 | // Create a new loop with the new init values for this loop. |
879 | SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); |
880 | newInits.append(newInitValues.begin(), newInitValues.end()); |
881 | auto newLoop = rewriter.create<scf::ForOp>( |
882 | forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), |
883 | forLoop.getStep(), newInits, |
884 | [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); |
885 | |
886 | // Merge the body of the new loop with the body of the old loops. |
887 | SmallVector<Value> sourceBlockArgs; |
888 | sourceBlockArgs.push_back(newLoop.getInductionVar()); |
889 | auto newRegionIterArgs = newLoop.getRegionIterArgs(); |
890 | sourceBlockArgs.append( |
891 | newRegionIterArgs.begin(), |
892 | std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); |
893 | rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); |
894 | rewriter.replaceOp( |
895 | forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); |
896 | loop = newLoop; |
897 | ivs.push_back(newLoop.getInductionVar()); |
898 | newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); |
899 | } |
900 | |
901 | // Update the loop body of the innermost loop to get new yield values. |
902 | LoopLikeOpInterface innerMostLoop = loops.back(); |
903 | FailureOr<LoopLikeOpInterface> newInnerMostLoop = |
904 | yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, |
905 | getNewTiledYieldsFn); |
906 | |
907 | if (failed(newInnerMostLoop)) |
908 | return innerMostLoop.emitOpError("failed to return additional yields"); |
909 | loops.back() = newInnerMostLoop.value(); |
910 | |
911 | // Make all other loops except the innermost loops yield the values returned |
912 | // by the inner loop. |
913 | for (auto [outerLoop, innerLoop] : |
914 | llvm::zip_equal(loops.drop_back(), loops.drop_front())) { |
915 | // Again assume that all the outer loops are scf.for operations. |
916 | auto outerForLoop = cast<scf::ForOp>(outerLoop); |
917 | auto outerLoopYield = |
918 | cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); |
919 | SmallVector<Value> newYields = |
920 | llvm::to_vector(outerLoopYield.getOperands()); |
921 | ValueRange additionalYields = |
922 | innerLoop->getResults().take_back(newInitValues.size()); |
923 | newYields.append(additionalYields.begin(), additionalYields.end()); |
924 | rewriter.setInsertionPoint(outerLoopYield); |
925 | rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); |
926 | } |
927 | return success(); |
928 | } |
929 | |
930 | /// Implementation of tiling transformation of `op` that implements the |
931 | /// `TilingInterface` using `scf.for` to iterate over the tiles. |
932 | FailureOr<scf::SCFTilingResult> |
933 | mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, |
934 | const scf::SCFTilingOptions &options) { |
935 | if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { |
936 | return failure(); |
937 | } |
938 | |
939 | OpBuilder::InsertionGuard guard(rewriter); |
940 | rewriter.setInsertionPointAfter(op); |
941 | |
942 | // 1. Get the range of the loops that are represented by the operation. |
943 | SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); |
944 | |
945 | // 2. Materialize the tile sizes and/or number of threads; |
946 | SmallVector<OpFoldResult> tileSizes, numThreads; |
947 | std::tie(args&: tileSizes, args&: numThreads) = |
948 | getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); |
949 | |
950 | // Check if it is safe to tile. This is hold over from previous iterations |
951 | // of tile to for-all. Consider dropping it. |
952 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { |
953 | checkSafeToTileToForall(op, tileSizes, numThreads); |
954 | } |
955 | |
956 | // 3. If there is an interchange specified, permute the iteration domain and |
957 | // the tile sizes. |
958 | SmallVector<int64_t> interchangeVector; |
959 | if (!options.interchangeVector.empty()) { |
960 | interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector, |
961 | iterationDomainSize: iterationDomain.size()); |
962 | assert(isPermutationVector(interchangeVector) && |
963 | "expected interchange vector to be a permutation"); |
964 | |
965 | applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector); |
966 | applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector); |
967 | if (!numThreads.empty()) |
968 | applyPermutationToVector(inVec&: numThreads, permutation: interchangeVector); |
969 | } |
970 | |
971 | FailureOr<TilingResult> tilingResult; |
972 | // 4. Define the lambda function used later to generate the body of the |
973 | // innermost tiled loop. |
974 | YieldTiledValuesFn innerYieldTiledValuesFn = |
975 | [&](RewriterBase &rewriter, Location loc, ValueRange ivs, |
976 | ValueRange regionIterArgs, SmallVector<Value> &tiledResults, |
977 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
978 | SmallVector<SmallVector<OpFoldResult>> &resultSizes) |
979 | -> LogicalResult { |
980 | // 4a. Compute the `offsets` and `sizes` to use for tiling. |
981 | SmallVector<OpFoldResult> offsets, sizes; |
982 | std::tie(args&: offsets, args&: sizes) = getTileOffsetAndSizes( |
983 | rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); |
984 | |
985 | // 4b. If interchange was provided, apply inverse of the interchange |
986 | // to get back the offsets/sizes in the order to be specified. |
987 | if (!interchangeVector.empty()) { |
988 | auto inversePermutation = invertPermutationVector(permutation: interchangeVector); |
989 | applyPermutationToVector(inVec&: offsets, permutation: inversePermutation); |
990 | applyPermutationToVector(inVec&: sizes, permutation: inversePermutation); |
991 | } |
992 | |
993 | // 5. Generate the tiled implementation within the inner most loop. |
994 | |
995 | // 5a. Clone the operation within the loop body. |
996 | auto clonedOp = cast<TilingInterface>( |
997 | cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); |
998 | |
999 | // 5b. Early return cloned op if tiling is not happening. We can not |
1000 | // return the original op because it could lead to `rewriter.replaceOp(op, |
1001 | // op->getResults())` and users would get crash. |
1002 | if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) { |
1003 | tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); |
1004 | tilingResult = |
1005 | TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), |
1006 | /*generatedSlices=*/{}}; |
1007 | return success(); |
1008 | } |
1009 | |
1010 | // 5c. Tile the cloned operation. |
1011 | tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, |
1012 | offsets, sizes, options); |
1013 | if (failed(Result: tilingResult)) { |
1014 | rewriter.eraseOp(op: clonedOp); |
1015 | return op.emitOpError("faild to tile operation"); |
1016 | } |
1017 | |
1018 | // 5d. Delete the cloned operation. |
1019 | rewriter.eraseOp(op: clonedOp); |
1020 | |
1021 | // 5e. Compute the offsets at which the result values are to be inserted |
1022 | // back into its destinations. |
1023 | for (auto [index, tiledValue] : |
1024 | llvm::enumerate(First&: tilingResult->tiledValues)) { |
1025 | tiledResults.push_back(Elt: tiledValue); |
1026 | SmallVector<OpFoldResult> resultOffset, resultSize; |
1027 | if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, |
1028 | sizes, resultOffset, resultSize, |
1029 | options))) { |
1030 | for (auto op : tilingResult->tiledOps) { |
1031 | rewriter.eraseOp(op); |
1032 | } |
1033 | return rewriter.notifyMatchFailure( |
1034 | op, "failed to get slice of result produced"); |
1035 | } |
1036 | resultOffsets.emplace_back(Args: std::move(resultOffset)); |
1037 | resultSizes.emplace_back(Args: std::move(resultSize)); |
1038 | } |
1039 | |
1040 | return success(); |
1041 | }; |
1042 | |
1043 | // 6. Find the destination tensors to use for the operation. |
1044 | FailureOr<SmallVector<Value>> maybeInits = |
1045 | createInitialTensorsForTiling(rewriter, op, tileSizes, options); |
1046 | if (failed(Result: maybeInits)) { |
1047 | return rewriter.notifyMatchFailure( |
1048 | op, "unable to create initial tensors for tiling"); |
1049 | } |
1050 | SmallVector<Value> &initTensors = maybeInits.value(); |
1051 | |
1052 | // 7. Generate the tiled loops nest using the callback defined above. |
1053 | SmallVector<LoopLikeOpInterface> loops; |
1054 | if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, |
1055 | tileSizes, numThreads, initTensors, |
1056 | innerYieldTiledValuesFn, loops))) |
1057 | return op.emitOpError("failed to generate tiling loops"); |
1058 | assert(succeeded(tilingResult) && |
1059 | "expected tiling result to be computed after loop generation"); |
1060 | |
1061 | if (loops.empty()) { |
1062 | // If loops are empty, the tiled op is used as the replacement for the |
1063 | // untiled op. |
1064 | return scf::SCFTilingResult{tilingResult->tiledOps, |
1065 | initTensors, |
1066 | loops, |
1067 | tilingResult->tiledValues, |
1068 | tilingResult->generatedSlices, |
1069 | {}}; |
1070 | } |
1071 | |
1072 | auto loopResults = llvm::map_to_vector(loops.front()->getResults(), |
1073 | [](OpResult r) -> Value { return r; }); |
1074 | |
1075 | // For the full reduction case, there is nothing more to do. |
1076 | if (options.reductionStrategy == |
1077 | scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) { |
1078 | return scf::SCFTilingResult{ |
1079 | tilingResult->tiledOps, initTensors, loops, loopResults, |
1080 | tilingResult->generatedSlices, {}}; |
1081 | } |
1082 | |
1083 | // The results of the loop needs to be merged. |
1084 | FailureOr<MergeResult> mergeResult = |
1085 | mergeTilingResults(rewriter, op, loopResults, options); |
1086 | if (failed(Result: mergeResult)) { |
1087 | return rewriter.notifyMatchFailure( |
1088 | op, "Failed to merge partial results from tiling"); |
1089 | } |
1090 | return scf::SCFTilingResult{tilingResult->tiledOps, |
1091 | initTensors, |
1092 | loops, |
1093 | mergeResult->replacements, |
1094 | tilingResult->generatedSlices, |
1095 | mergeResult->mergeOps}; |
1096 | } |
1097 | |
1098 | FailureOr<scf::SCFTilingResult> |
1099 | mlir::scf::tileReductionUsingScf(RewriterBase &b, |
1100 | PartialReductionOpInterface op, |
1101 | ArrayRef<OpFoldResult> tileSize) { |
1102 | scf::SCFTilingOptions options; |
1103 | options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); |
1104 | options.setReductionTilingStrategy( |
1105 | scf::SCFTilingOptions::ReductionTilingStrategy:: |
1106 | PartialReductionOuterReduction); |
1107 | options.setTileSizes(tileSize); |
1108 | return tileUsingSCF(b, op, options); |
1109 | } |
1110 | |
1111 | //===----------------------------------------------------------------------===// |
1112 | // tileConsumerAndFuseProducersUsingSCF implementation. |
1113 | //===----------------------------------------------------------------------===// |
1114 | |
1115 | /// Return the untiled producer whose slice is used in a tiled consumer. The |
1116 | /// method traverses the tile loop nest (`loops`) if needed, and returns the |
1117 | /// `iter_args` of the outer most that is encountered. Traversing the |
1118 | /// iter_args indicates that this is a destination operand of the consumer. If |
1119 | /// there was no loop traversal needed, the second value of the returned tuple |
1120 | /// is empty. |
1121 | static std::tuple<OpResult, std::optional<OpOperand *>> |
1122 | getUntiledProducerFromSliceSource(OpOperand *source, |
1123 | ArrayRef<LoopLikeOpInterface> loops) { |
1124 | std::optional<OpOperand *> destinationIterArg; |
1125 | assert(!loops.empty() && "expected non empty loops container"); |
1126 | auto loopIt = loops.rbegin(); |
1127 | while (loopIt != loops.rend() && isa<BlockArgument>(Val: source->get())) { |
1128 | auto iterArg = cast<BlockArgument>(Val: source->get()); |
1129 | auto loop = *loopIt; |
1130 | if (iterArg.getOwner()->getParentOp() != loop) |
1131 | break; |
1132 | source = loop.getTiedLoopInit(iterArg); |
1133 | loopIt++; |
1134 | } |
1135 | if (loopIt == loops.rend()) |
1136 | destinationIterArg = source; |
1137 | return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg}; |
1138 | } |
1139 | |
1140 | /// Implementation of fusing producer of a single slice by computing the |
1141 | /// slice of the producer in-place. |
1142 | std::optional<scf::SCFFuseProducerOfSliceResult> |
1143 | mlir::scf::tileAndFuseProducerOfSlice( |
1144 | RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, |
1145 | MutableArrayRef<LoopLikeOpInterface> loops) { |
1146 | // 1. Get the producer of the source (potentially walking through |
1147 | // `iter_args` of nested `scf.for`) |
1148 | auto [fusableProducer, destinationInitArg] = |
1149 | getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), |
1150 | loops); |
1151 | if (!fusableProducer) |
1152 | return std::nullopt; |
1153 | unsigned resultNumber = fusableProducer.getResultNumber(); |
1154 | |
1155 | OpBuilder::InsertionGuard g(rewriter); |
1156 | rewriter.setInsertionPoint(candidateSliceOp); |
1157 | |
1158 | // 2. Clone the fused producer |
1159 | // 2a. Compute the destination operands to use for the cloned operation. |
1160 | SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; |
1161 | Operation *fusableProducerOp = fusableProducer.getOwner(); |
1162 | if (isa<DestinationStyleOpInterface>(fusableProducerOp) && |
1163 | failed(tensor::getOrCreateDestinations( |
1164 | rewriter, fusableProducerOp->getLoc(), fusableProducerOp, |
1165 | origDestinationTensors))) |
1166 | return std::nullopt; |
1167 | |
1168 | clonedOpDestinationTensors = origDestinationTensors; |
1169 | if (destinationInitArg && |
1170 | isa<DestinationStyleOpInterface>(fusableProducerOp)) { |
1171 | // 2b. If the producer is also destination style, then to maintain the |
1172 | // destination passing style, update the destination of the producer to be |
1173 | // the source of the slice. |
1174 | clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); |
1175 | } |
1176 | // 2c. Clone the fused producer. |
1177 | Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( |
1178 | rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors); |
1179 | // 2d. Update the source of the candidateSlice to be the cloned producer. |
1180 | // Easier to just clone the slice with different source since |
1181 | // replacements and DCE of cloned ops becomes easier |
1182 | SmallVector<Value> candidateSliceOpOperands = |
1183 | llvm::to_vector(candidateSliceOp->getOperands()); |
1184 | candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber); |
1185 | tensor::ExtractSliceOp clonedCandidateSliceOp = |
1186 | mlir::clone(rewriter, candidateSliceOp, |
1187 | candidateSliceOp->getResultTypes(), candidateSliceOpOperands); |
1188 | |
1189 | // 3. Generate the tiled implementation of the producer of the source |
1190 | FailureOr<TilingResult> tileAndFuseResult = |
1191 | tensor::replaceExtractSliceWithTiledProducer( |
1192 | rewriter, clonedCandidateSliceOp, |
1193 | clonedProducerOp->getResult(idx: resultNumber)); |
1194 | if (failed(Result: tileAndFuseResult)) |
1195 | return std::nullopt; |
1196 | // Note: Do not delete the candidateSliceOp, since its passed in from the |
1197 | // caller. |
1198 | rewriter.replaceAllUsesWith(candidateSliceOp, |
1199 | tileAndFuseResult->tiledValues[0]); |
1200 | rewriter.eraseOp(op: clonedCandidateSliceOp); |
1201 | rewriter.eraseOp(op: clonedProducerOp); |
1202 | |
1203 | // 3. If the slice is for a destination operand, for example, |
1204 | // |
1205 | // ```mlir |
1206 | // %0 = linalg.init |
1207 | // %1 = linalg.fill .. outs(%0 : ) |
1208 | // %2 = scf.for .. iter_args(%arg0 = %1) { |
1209 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
1210 | // %4 = tensor.extract_slice %arg1 [..] |
1211 | // .. = linalg.matmul .. outs(%4 : ) |
1212 | // } |
1213 | // } |
1214 | // ``` |
1215 | // |
1216 | // the IR is currently |
1217 | // |
1218 | // ``` |
1219 | // %0 = linalg.init |
1220 | // %1 = linalg.fill |
1221 | // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { |
1222 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
1223 | // %4 = tensor.extract_slice %arg1[..] |
1224 | // %5 = linalg.fill .. outs(%4 : ) |
1225 | // .. = linalg.matmul .. outs(%5 : ) |
1226 | // } |
1227 | // } |
1228 | // ``` |
1229 | // |
1230 | // The untiled `linalg.fill` is still used as the `init_value` since it |
1231 | // was originally a destination operand of the untiled `linalg.matmul`. |
1232 | // When fusing an operand that is a destination operand, the iter_arg of |
1233 | // the outer most loop should be changed to use the destination of the |
1234 | // fused operation. With this the IR will be. |
1235 | // |
1236 | // ``` |
1237 | // %0 = linalg.init |
1238 | // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { |
1239 | // %2 = scf.for .. iter_args(%arg1 = %arg0) { |
1240 | // %3 = tensor.extract_slice %arg1[..] |
1241 | // %4 = linalg.fill .. outs(%3 : ) |
1242 | // .. = linalg.matmul .. outs(%4 : ) |
1243 | // } |
1244 | // } |
1245 | // ``` |
1246 | if (destinationInitArg && |
1247 | isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { |
1248 | loops.front() |
1249 | ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] |
1250 | .set(origDestinationTensors[resultNumber]); |
1251 | } |
1252 | return scf::SCFFuseProducerOfSliceResult{ |
1253 | fusableProducer, tileAndFuseResult->tiledValues[0], |
1254 | tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; |
1255 | } |
1256 | |
1257 | /// Reconstruct the fused producer from within the tiled-and-fused code. |
1258 | FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( |
1259 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
1260 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
1261 | MutableArrayRef<LoopLikeOpInterface> loops, |
1262 | ArrayRef<unsigned> yieldResultNumber) { |
1263 | if (loops.empty()) |
1264 | return success(); |
1265 | |
1266 | Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), |
1267 | *tiledOwner = fusedProducerInfo.tiledOps[0]; |
1268 | |
1269 | Location loc = originalOwner->getLoc(); |
1270 | // a. collect all init Value to be appended |
1271 | SmallVector<unsigned> initNumberList = |
1272 | yieldResultNumber.empty() ? llvm::to_vector(Range: llvm::seq<unsigned>( |
1273 | Begin: 0, End: originalOwner->getNumResults())) |
1274 | : llvm::to_vector(Range&: yieldResultNumber); |
1275 | SmallVector<Value> initValueList; |
1276 | for (const auto &resultNumber : initNumberList) { |
1277 | FailureOr<Value> initValue = tensor::getOrCreateDestination( |
1278 | b&: rewriter, loc, opResult: originalOwner->getResult(idx: resultNumber)); |
1279 | if (succeeded(Result: initValue)) { |
1280 | initValueList.push_back(Elt: initValue.value()); |
1281 | } else { |
1282 | return failure(); |
1283 | } |
1284 | } |
1285 | |
1286 | SmallVector<Operation *> generatedSlices; |
1287 | YieldTiledValuesFn newYieldValuesFn = |
1288 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
1289 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
1290 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
1291 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { |
1292 | OpBuilder::InsertionGuard g(innerRewriter); |
1293 | |
1294 | // get sliceOp tile information |
1295 | SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), |
1296 | sliceSizes = sliceOp.getMixedSizes(); |
1297 | |
1298 | // expect all strides of sliceOp being 1 |
1299 | if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) |
1300 | return failure(); |
1301 | |
1302 | unsigned sliceResultNumber = |
1303 | fusedProducerInfo.origProducer.getResultNumber(); |
1304 | |
1305 | auto tilableOp = cast<TilingInterface>(originalOwner); |
1306 | // b. get iterDomain Offset and Sizes based on sliceOp tile |
1307 | SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; |
1308 | // skip tensor.pack/unpack/pad, which expects single opResult |
1309 | if (tilableOp->getNumResults() > 1 && |
1310 | failed(tilableOp.getIterationDomainTileFromResultTile( |
1311 | rewriter, sliceResultNumber, sliceOffset, sliceSizes, |
1312 | iterDomainOffset, iterDomainSizes))) { |
1313 | // In theory, it is unnecessary to raise an error here. Actually |
1314 | // although it fails to reconstruct the result tensor, it should not |
1315 | // broke current fusion anyway. The reason why we must return failure |
1316 | // currently is that the callback function `newYieldValuesFn` will be |
1317 | // called after new init operand(s) has already been appended. It will |
1318 | // take more refactoring to make sure the init operands are added |
1319 | // consistently in the future. For more details, please refer to: |
1320 | // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 |
1321 | return failure(); |
1322 | } |
1323 | |
1324 | // c. calculate offsets and sizes info of all OpResults respectively based |
1325 | // on iteration Domain Tile |
1326 | SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; |
1327 | for (const auto &resultNumber : initNumberList) { |
1328 | if (resultNumber == sliceResultNumber) { |
1329 | offsetList.push_back(Elt: sliceOffset); |
1330 | sizesList.push_back(Elt: sliceSizes); |
1331 | } else { |
1332 | assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); |
1333 | // infer result tile according to the iteration domain tile |
1334 | SmallVector<OpFoldResult> offset, sizes; |
1335 | if (failed(tilableOp.getResultTilePosition( |
1336 | rewriter, resultNumber, iterDomainOffset, iterDomainSizes, |
1337 | offset, sizes))) { |
1338 | return failure(); |
1339 | } |
1340 | offsetList.push_back(Elt: offset); |
1341 | sizesList.push_back(Elt: sizes); |
1342 | } |
1343 | } |
1344 | |
1345 | // d. create `extract_slice` for `iter_args` for DPS operation if |
1346 | // necessary |
1347 | if (auto tiledDestStyleOp = |
1348 | dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { |
1349 | rewriter.setInsertionPoint(tiledDestStyleOp); |
1350 | for (const auto &&[index, newRegionArg] : |
1351 | llvm::enumerate(First&: newRegionIterArgs)) { |
1352 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
1353 | loc, newRegionArg, offsetList[index], sizesList[index], |
1354 | SmallVector<OpFoldResult>(offsetList[index].size(), |
1355 | rewriter.getIndexAttr(1))); |
1356 | generatedSlices.push_back(Elt: destSlice); |
1357 | unsigned resultNumber = initNumberList[index]; |
1358 | rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { |
1359 | tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); |
1360 | }); |
1361 | } |
1362 | } |
1363 | |
1364 | // e. prepare tiled offset and sizes for later `insert_slice` creation by |
1365 | // caller |
1366 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
1367 | rewriter.setInsertionPoint(block->getTerminator()); |
1368 | for (const auto &&[index, resultNumber] : llvm::enumerate(First&: initNumberList)) { |
1369 | tiledResult.push_back(Elt: tiledOwner->getResult(idx: resultNumber)); |
1370 | tiledOffset.emplace_back(Args&: offsetList[index]); |
1371 | tiledSizes.emplace_back(Args&: sizesList[index]); |
1372 | } |
1373 | return success(); |
1374 | }; |
1375 | |
1376 | if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, |
1377 | newYieldValuesFn))) { |
1378 | return failure(); |
1379 | } |
1380 | return generatedSlices; |
1381 | } |
1382 | |
1383 | namespace { |
1384 | |
1385 | //===----------------------------------------------------------------------===// |
1386 | // SliceTrackingListener |
1387 | //===----------------------------------------------------------------------===// |
1388 | |
1389 | /// This class is a listener for tracking the insertion and removal of |
1390 | /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy |
1391 | /// fusion algorithm to apply cleanup patterns in between fusion steps. |
1392 | class SliceTrackingListener : public RewriterBase::Listener { |
1393 | public: |
1394 | explicit SliceTrackingListener( |
1395 | std::optional<FrozenRewritePatternSet> patterns); |
1396 | SliceTrackingListener() = default; |
1397 | |
1398 | /// Adds the given list of operations to the worklist, and if present, |
1399 | /// applies the list of `patterns` to the newly added operations. This only |
1400 | /// processes the given operations and any newly inserted ones by the |
1401 | /// pattern set. |
1402 | LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps); |
1403 | |
1404 | /// Add to the new operation worklist if it is an extract_slice. |
1405 | void notifyOperationInserted(Operation *op, |
1406 | OpBuilder::InsertPoint previous) override; |
1407 | |
1408 | /// Shared helper for operation removal from the worklist. |
1409 | void removeOp(Operation *op); |
1410 | |
1411 | /// Remove the operation from the worklist. |
1412 | void notifyOperationErased(Operation *op) override; |
1413 | |
1414 | /// Remove the operation from the worklist. |
1415 | void notifyOperationReplaced(Operation *op, ValueRange replacement) override; |
1416 | |
1417 | /// The worklist for this transformation keeps track of the slices to visit |
1418 | /// next for fusion. |
1419 | std::deque<tensor::ExtractSliceOp> worklist; |
1420 | |
1421 | private: |
1422 | /// Optional pattern set to apply when adding new operations to the |
1423 | /// worklist. |
1424 | std::optional<FrozenRewritePatternSet> patterns = std::nullopt; |
1425 | }; |
1426 | |
1427 | SliceTrackingListener::SliceTrackingListener( |
1428 | std::optional<FrozenRewritePatternSet> p) { |
1429 | patterns = std::move(p); |
1430 | } |
1431 | |
1432 | LogicalResult |
1433 | SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) { |
1434 | for (Operation *op : ops) { |
1435 | if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op)) |
1436 | worklist.push_back(slice); |
1437 | } |
1438 | |
1439 | if (!patterns) |
1440 | return success(); |
1441 | |
1442 | return applyOpPatternsGreedily( |
1443 | ops, patterns: patterns.value(), |
1444 | config: GreedyRewriteConfig().setListener(this).setStrictness( |
1445 | GreedyRewriteStrictness::ExistingAndNewOps)); |
1446 | } |
1447 | |
1448 | void SliceTrackingListener::notifyOperationInserted( |
1449 | Operation *op, OpBuilder::InsertPoint previous) { |
1450 | auto slice = dyn_cast<tensor::ExtractSliceOp>(op); |
1451 | if (!slice) |
1452 | return; |
1453 | worklist.push_back(slice); |
1454 | } |
1455 | |
1456 | // Scan the worklist for the given op and remove it if present. The |
1457 | // expectation is for the worklist to be small and for removal to be |
1458 | // relatively rare. |
1459 | void SliceTrackingListener::removeOp(Operation *op) { |
1460 | if (!isa<tensor::ExtractSliceOp>(op)) |
1461 | return; |
1462 | auto iter = worklist.begin(); |
1463 | while (iter != worklist.end()) { |
1464 | if (*iter == op) |
1465 | break; |
1466 | iter++; |
1467 | } |
1468 | if (iter == worklist.end()) |
1469 | return; |
1470 | |
1471 | worklist.erase(iter); |
1472 | } |
1473 | |
1474 | void SliceTrackingListener::notifyOperationErased(Operation *op) { |
1475 | removeOp(op); |
1476 | } |
1477 | |
1478 | void SliceTrackingListener::notifyOperationReplaced(Operation *op, |
1479 | ValueRange replacement) { |
1480 | removeOp(op); |
1481 | } |
1482 | |
1483 | //===----------------------------------------------------------------------===// |
1484 | // ReplacementListener |
1485 | //===----------------------------------------------------------------------===// |
1486 | |
1487 | /// Listener that tracks updates replacements for values which can be mutated. |
1488 | /// This listener runs on top of the existing listener for the rewriter, |
1489 | /// to make sure external users can still run listeners. |
1490 | class ReplacementListener : public RewriterBase::ForwardingListener { |
1491 | public: |
1492 | ReplacementListener(DenseMap<Value, Value> &replacements, |
1493 | OpBuilder::Listener *listener) |
1494 | : ForwardingListener(listener), replacements(replacements) {} |
1495 | |
1496 | void updateReplacementValues(ValueRange origValues, |
1497 | ValueRange replaceValues) { |
1498 | // This can probably be written better, but just iterates over the map |
1499 | // and the new replacements for now. |
1500 | for (auto &[key, val] : replacements) { |
1501 | for (auto [orig, replace] : llvm::zip_equal(t&: origValues, u&: replaceValues)) { |
1502 | if (val == orig) { |
1503 | val = replace; |
1504 | } |
1505 | } |
1506 | } |
1507 | } |
1508 | |
1509 | void notifyOperationReplaced(Operation *op, Operation *newOp) override { |
1510 | ForwardingListener::notifyOperationReplaced(op, newOp); |
1511 | updateReplacementValues(origValues: op->getResults(), replaceValues: newOp->getResults()); |
1512 | } |
1513 | |
1514 | void notifyOperationReplaced(Operation *op, ValueRange values) override { |
1515 | ForwardingListener::notifyOperationReplaced(op, replacement: values); |
1516 | updateReplacementValues(origValues: op->getResults(), replaceValues: values); |
1517 | } |
1518 | |
1519 | private: |
1520 | DenseMap<Value, Value> &replacements; |
1521 | }; |
1522 | |
1523 | } // namespace |
1524 | |
1525 | /// Implementation of tile consumer and fuse producer greedily. |
1526 | FailureOr<scf::SCFTileAndFuseResult> |
1527 | mlir::scf::tileConsumerAndFuseProducersUsingSCF( |
1528 | RewriterBase &rewriter, TilingInterface consumer, |
1529 | const scf::SCFTileAndFuseOptions &options) { |
1530 | // This transformation is only valid for ops that return values (i.e. not |
1531 | // valid to use with operations that have memref operands). |
1532 | if (!consumer->getNumResults()) { |
1533 | return rewriter.notifyMatchFailure( |
1534 | consumer, "invalid pattern for op with no results"); |
1535 | } |
1536 | |
1537 | // 1. First tile the consumer. |
1538 | SetVector<Operation *> fusedProducers, tiledAndFusedOps; |
1539 | |
1540 | FailureOr<scf::SCFTilingResult> tilingResult = |
1541 | tileUsingSCF(rewriter, consumer, options.tilingOptions); |
1542 | |
1543 | if (failed(Result: tilingResult)) |
1544 | return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); |
1545 | tiledAndFusedOps.insert_range(tilingResult->tiledOps); |
1546 | |
1547 | DenseMap<Value, Value> replacements; |
1548 | for (auto [origVal, replacement] : |
1549 | llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { |
1550 | replacements[origVal] = replacement; |
1551 | } |
1552 | |
1553 | // If there are no loops generated, fusion is immaterial. |
1554 | auto &loops = tilingResult->loops; |
1555 | if (loops.empty()) { |
1556 | return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, |
1557 | replacements}; |
1558 | } |
1559 | |
1560 | // Since the loop gets potentially replaced during fusion, we need to track |
1561 | // the mutation of replacement values. To do this, we attach a listener to |
1562 | // update the replacements as they happen. |
1563 | OpBuilder::Listener *previousListener = rewriter.getListener(); |
1564 | auto resetListener = |
1565 | llvm::make_scope_exit(F: [&]() { rewriter.setListener(previousListener); }); |
1566 | ReplacementListener replaceListener(replacements, previousListener); |
1567 | rewriter.setListener(&replaceListener); |
1568 | |
1569 | // 2. Typically, the operands of the tiled operation are slices of the |
1570 | // operands of the untiled operation. These are expressed in IR using |
1571 | // `tensor.extract_slice` operations with source being the operands of |
1572 | // the untiled operation. Create a worklist of these |
1573 | // `tensor.extract_slice` operations. If the producers of the source of |
1574 | // the `tensor.extract_slice` can be tiled such that the tiled value is |
1575 | // generated in-place, that effectively tiles + fuses the operations. |
1576 | struct WorklistItem { |
1577 | tensor::ExtractSliceOp candidateSlice; |
1578 | SCFTileAndFuseOptions::ControlFnResult controlFnResult; |
1579 | }; |
1580 | |
1581 | SliceTrackingListener sliceTracker = |
1582 | SliceTrackingListener(options.cleanupPatterns); |
1583 | |
1584 | if (failed( |
1585 | sliceTracker.insertAndApplyPatterns(ops: tilingResult->generatedSlices))) { |
1586 | return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); |
1587 | } |
1588 | OpBuilder::InsertionGuard g(rewriter); |
1589 | while (!sliceTracker.worklist.empty()) { |
1590 | auto candidateSlice = sliceTracker.worklist.front(); |
1591 | sliceTracker.worklist.pop_front(); |
1592 | |
1593 | auto [fusableProducer, destinationInitArg] = |
1594 | getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), |
1595 | loops); |
1596 | if (!fusableProducer) |
1597 | continue; |
1598 | |
1599 | std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = |
1600 | options.fusionControlFn(candidateSlice, fusableProducer, |
1601 | destinationInitArg.has_value()); |
1602 | if (!controlFnResult) |
1603 | continue; |
1604 | |
1605 | WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; |
1606 | |
1607 | // The operands of the fused producer might themselved be slices of |
1608 | // values produced by operations that implement the `TilingInterface`. |
1609 | // Add these operations to the worklist. |
1610 | std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = |
1611 | tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, |
1612 | loops); |
1613 | if (!fusedResult) |
1614 | continue; |
1615 | |
1616 | SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices; |
1617 | |
1618 | if (worklistItem.controlFnResult.yieldProducerReplacement) { |
1619 | // Reconstruct and yield all opResult of fusableProducerOp by default. |
1620 | // The caller can specific which one to yield by designating optional |
1621 | // argument named `yieldResultNumber` of |
1622 | // `yieldReplacementForFusedProducer`. |
1623 | Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); |
1624 | FailureOr<SmallVector<Operation *>> newSlices = |
1625 | yieldReplacementForFusedProducer(rewriter, |
1626 | worklistItem.candidateSlice, |
1627 | fusedResult.value(), loops); |
1628 | if (failed(Result: newSlices)) { |
1629 | return rewriter.notifyMatchFailure( |
1630 | arg&: fusableProducerOp, msg: "failed to replacement value for this " |
1631 | "operation from within the tiled loop"); |
1632 | } |
1633 | worklistCandidates.append(RHS: newSlices.value()); |
1634 | for (auto [index, result] : |
1635 | llvm::enumerate(First: fusableProducerOp->getResults())) { |
1636 | replacements[result] = loops.front()->getResult( |
1637 | loops.front()->getNumResults() - |
1638 | fusableProducerOp->getNumResults() + index); |
1639 | } |
1640 | } |
1641 | if (Operation *tiledAndFusedOp = |
1642 | fusedResult->tiledAndFusedProducer.getDefiningOp()) { |
1643 | fusedProducers.insert(X: fusedResult->origProducer.getDefiningOp()); |
1644 | tiledAndFusedOps.insert(X: tiledAndFusedOp); |
1645 | } |
1646 | |
1647 | if (failed(Result: sliceTracker.insertAndApplyPatterns(ops: worklistCandidates))) { |
1648 | return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); |
1649 | } |
1650 | } |
1651 | |
1652 | return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, |
1653 | replacements}; |
1654 | } |
1655 | |
1656 | //===----------------------------------------------------------------------===// |
1657 | // tileAndFuseConsumerUsingSCF implementation. |
1658 | //===----------------------------------------------------------------------===// |
1659 | |
1660 | /// A utility function that checks whether the only use of the result of a |
1661 | /// tensor.insert_slice op is in a scf.yield op. |
1662 | static LogicalResult |
1663 | checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { |
1664 | Value result = candidateSliceOp.getResult(); |
1665 | Value::use_range uses = result.getUses(); |
1666 | if (!llvm::hasSingleElement(C&: uses)) { |
1667 | LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); |
1668 | return failure(); |
1669 | } |
1670 | OpOperand &operandUse = (*uses.begin()); |
1671 | Operation *userOp = operandUse.getOwner(); |
1672 | if (!isa<scf::YieldOp>(userOp)) { |
1673 | LLVM_DEBUG(llvm::dbgs() |
1674 | << "Expected scf.yield to be the only user, but got -> " |
1675 | << (*userOp)); |
1676 | return failure(); |
1677 | } |
1678 | if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { |
1679 | LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " |
1680 | "be in the same block\n"); |
1681 | return failure(); |
1682 | } |
1683 | return success(); |
1684 | } |
1685 | |
1686 | /// An utility to get the first user of the given loopOp. If any of user stay |
1687 | /// in different block of loopOp, return failure. |
1688 | static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) { |
1689 | if (!isa<LoopLikeOpInterface>(loopOp)) |
1690 | return failure(); |
1691 | Operation *firstUserOfLoop = nullptr; |
1692 | for (Operation *userOp : loopOp->getUsers()) { |
1693 | // `ParallelInsertSlice` located inside `InParallelOp` has no same parent |
1694 | // block with any other types of operation. Thus, just redirecting to its |
1695 | // parent `InParallelOp`. E.g. |
1696 | // |
1697 | // ``` |
1698 | // %1 = scf.for { |
1699 | // ... |
1700 | // } |
1701 | // %2 = consumerOp ins(%1, ...) |
1702 | // scf.forall.in_parallel { |
1703 | // tensor.parallel_insert_slice %1 |
1704 | // } |
1705 | // ``` |
1706 | // where `InParallelOp` but not `ParallelInsertSlice` stays in the same |
1707 | // same block with `consumerOp`. |
1708 | if (isa<tensor::ParallelInsertSliceOp>(userOp)) |
1709 | userOp = userOp->getParentOfType<scf::InParallelOp>(); |
1710 | |
1711 | if (loopOp->getBlock() != userOp->getBlock()) |
1712 | return failure(); |
1713 | |
1714 | if (!firstUserOfLoop || userOp->isBeforeInBlock(other: firstUserOfLoop)) |
1715 | firstUserOfLoop = userOp; |
1716 | } |
1717 | return firstUserOfLoop; |
1718 | } |
1719 | |
1720 | /// This utility currently checks whether the first userOp of loop is NOT |
1721 | /// before the last defineOp of consumer operand. Because that we need to move |
1722 | /// the whole loop structure right before the `firstUserOfLoop`. This utility |
1723 | /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice |
1724 | /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that: |
1725 | /// |
1726 | /// ``` |
1727 | /// %0 = scf.for() { |
1728 | /// ... |
1729 | /// } |
1730 | /// ... |
1731 | /// %1 = firstUserOfLoop(%0) |
1732 | /// ... |
1733 | /// %2 = lastDefOfConsumerOperand |
1734 | /// ... |
1735 | /// %3 = consumerOp(%2) |
1736 | /// ``` |
1737 | /// |
1738 | /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it |
1739 | /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`, |
1740 | /// a.k.a. use-def chain violation: |
1741 | /// |
1742 | /// ``` |
1743 | /// %0:2 = scf.for() { |
1744 | /// // use before define error |
1745 | /// %3 = tiledConsumerOp(%2) |
1746 | /// } |
1747 | /// %1 = firstUserOfLoop(%0) |
1748 | /// ... |
1749 | /// %2 = lastDefOfConsumerOperand |
1750 | /// ``` |
1751 | /// |
1752 | /// @param loopOp: loop operation |
1753 | /// @param consumerOp: consumer operation |
1754 | /// @param reorderOperations: the flag controls whether to reorder the |
1755 | /// backward slice w.r.t. the defineOp of `consumerOp` operands. |
1756 | /// @return: computed backward slice of consumerOp, but excluding those |
1757 | /// already dominates `firstUserOfLoop`. |
1758 | static FailureOr<llvm::SetVector<Operation *>> |
1759 | checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, |
1760 | bool reorderOperations) { |
1761 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); |
1762 | if (failed(Result: firstUserOfLoop)) |
1763 | return failure(); |
1764 | |
1765 | BackwardSliceOptions options; |
1766 | DominanceInfo dominanceInfo; |
1767 | options.inclusive = true; |
1768 | options.omitBlockArguments = true; |
1769 | bool includeLoopOp = false; |
1770 | options.filter = [&](Operation *op) { |
1771 | if (op == loopOp) { |
1772 | includeLoopOp = true; |
1773 | return false; |
1774 | } |
1775 | // Cut off the slice to not include any operation that already dominates |
1776 | // firstUserOfLoop. |
1777 | return !dominanceInfo.properlyDominates(a: op, b: *firstUserOfLoop); |
1778 | }; |
1779 | llvm::SetVector<Operation *> slice; |
1780 | for (auto operand : consumerOp->getOperands()) { |
1781 | LogicalResult result = getBackwardSlice(root: operand, backwardSlice: &slice, options); |
1782 | assert(result.succeeded() && "expected a backward slice"); |
1783 | (void)result; |
1784 | } |
1785 | |
1786 | if (!slice.empty()) { |
1787 | // If consumerOp has one producer, which is also the user of loopOp. |
1788 | // E.g. |
1789 | // ``` |
1790 | // %0 = %loopOp |
1791 | // %1 = consumerOp1 ins(%0) |
1792 | // %2 = consumerOp2 ins(%0, %1) |
1793 | // ``` |
1794 | // We can not fuse consumerOp2 into loopOp due to UD chain, unless |
1795 | // consumerOp1 has already been fused into loopOp before. |
1796 | if (includeLoopOp || !reorderOperations) |
1797 | return failure(); |
1798 | } |
1799 | |
1800 | return slice; |
1801 | } |
1802 | |
1803 | /// Fetches the OpOperand of the first valid user (and use) of the value `val` |
1804 | /// which implements `TilingInterface` and `DestinationStyleOpInterface`. |
1805 | /// Returns failure otherwise. |
1806 | static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, |
1807 | Operation *loopOp, |
1808 | unsigned resultNumber) { |
1809 | if (!isa<LoopLikeOpInterface>(loopOp)) |
1810 | return failure(); |
1811 | Value val = loopOp->getResult(idx: resultNumber); |
1812 | Block *loopBlock = loopOp->getBlock(); |
1813 | for (OpOperand &opOperand : val.getUses()) { |
1814 | Operation *consumerOp = opOperand.getOwner(); |
1815 | // Step 1. Check if the user is tilable. |
1816 | if (!isa<TilingInterface>(consumerOp) || |
1817 | !isa<DestinationStyleOpInterface>(consumerOp)) { |
1818 | // TODO: We have to init result of consumer before scf.for, use |
1819 | // DestinationStyleOpInterface to get result shape from init for now. |
1820 | // Add support for other op such as op has InferTypeOpInterface. |
1821 | continue; |
1822 | } |
1823 | // Step 2. Check if user stay in the same block. |
1824 | if (loopBlock != consumerOp->getBlock()) |
1825 | continue; |
1826 | // Step 3. Check if user has succeeding user. Otherwise, it usually |
1827 | // represents already tiled. |
1828 | if (consumerOp->use_empty()) |
1829 | continue; |
1830 | // Step 4. Check assumption for loop with `reorderOperations` enabled. |
1831 | FailureOr<llvm::SetVector<Operation *>> slice = |
1832 | checkAssumptionForLoop(loopOp, consumerOp, reorderOperations: true); |
1833 | if (failed(Result: slice)) |
1834 | continue; |
1835 | // Step 5. If backward sice is not empty, move them before |
1836 | // firstUserOfLoop. |
1837 | if (!slice->empty()) { |
1838 | mlir::topologicalSort(toSort: *slice); |
1839 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); |
1840 | assert(succeeded(firstUserOfLoop) && "First user of loop is not found"); |
1841 | for (auto op : *slice) { |
1842 | rewriter.moveOpBefore(op, existingOp: *firstUserOfLoop); |
1843 | } |
1844 | } |
1845 | return &opOperand; |
1846 | } |
1847 | return failure(); |
1848 | } |
1849 | |
1850 | /// Check that the loop is perfectly nested. |
1851 | /// The loops are expected to be ordered from outer most to inner most. |
1852 | /// For example: |
1853 | /// ``` |
1854 | /// %0 = scf.for() |
1855 | /// %1 = scf.for() |
1856 | /// %2 = scf.for() |
1857 | /// %3 = ... |
1858 | /// yield %3 |
1859 | /// yield %2 |
1860 | /// yield %1 |
1861 | /// ``` |
1862 | /// Here loops should be [%0, %1]. |
1863 | static bool |
1864 | isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) { |
1865 | assert(!loops.empty() && "unexpected empty loop nest"); |
1866 | if (loops.size() == 1) { |
1867 | return isa_and_nonnull<scf::ForOp>(loops.front().getOperation()); |
1868 | } |
1869 | for (auto [outerLoop, innerLoop] : |
1870 | llvm::zip_equal(loops.drop_back(), loops.drop_front())) { |
1871 | auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation()); |
1872 | auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation()); |
1873 | if (!outerFor || !innerFor) { |
1874 | return false; |
1875 | } |
1876 | auto outerBBArgs = outerFor.getRegionIterArgs(); |
1877 | auto innerIterArgs = innerFor.getInitArgs(); |
1878 | if (outerBBArgs.size() != innerIterArgs.size()) { |
1879 | return false; |
1880 | } |
1881 | |
1882 | for (auto [outerBBArg, innerIterArg] : |
1883 | llvm::zip_equal(outerBBArgs, innerIterArgs)) { |
1884 | if (!llvm::hasSingleElement(outerBBArg.getUses()) || |
1885 | innerIterArg != outerBBArg) { |
1886 | return false; |
1887 | } |
1888 | } |
1889 | |
1890 | ValueRange outerYields = |
1891 | cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands(); |
1892 | ValueRange innerResults = innerFor.getResults(); |
1893 | if (outerYields.size() != innerResults.size()) { |
1894 | return false; |
1895 | } |
1896 | for (auto [outerYield, innerResult] : |
1897 | llvm::zip_equal(outerYields, innerResults)) { |
1898 | if (!llvm::hasSingleElement(innerResult.getUses()) || |
1899 | outerYield != innerResult) { |
1900 | return false; |
1901 | } |
1902 | } |
1903 | } |
1904 | return true; |
1905 | } |
1906 | |
1907 | /// Fetch the untiled consumer of the outermost scf.for's result which is |
1908 | /// yielded by a tensor.insert_slice from the innermost scf.for. This function |
1909 | /// makes the following assumptions : |
1910 | /// 1. tensor.insert_slice has scf.yield as its only user. |
1911 | /// 2. scf.for's corresponding result has only one use. |
1912 | /// 3. The `loops` passed in are perfectly nested `scf.for` operations. |
1913 | static FailureOr<OpOperand *> |
1914 | getUntiledConsumerFromSlice(RewriterBase &rewriter, |
1915 | tensor::InsertSliceOp candidateSliceOp, |
1916 | MutableArrayRef<LoopLikeOpInterface> loops) { |
1917 | assert(!loops.empty() && "unexpected loops to be empty"); |
1918 | // 1. Expect slice to be part of the body of the inner most loop. |
1919 | Operation *containingOp = candidateSliceOp->getParentOp(); |
1920 | if (containingOp != loops.back()) { |
1921 | return rewriter.notifyMatchFailure( |
1922 | candidateSliceOp, |
1923 | "expected slice to be within body of inner-most loop"); |
1924 | } |
1925 | |
1926 | // 2. Check that the loop is perfectly nested. |
1927 | if (!isPerfectlyNestedForLoops(loops)) { |
1928 | return rewriter.notifyMatchFailure( |
1929 | candidateSliceOp, "expected passed loops to be perfectly nested."); |
1930 | } |
1931 | |
1932 | if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) |
1933 | return failure(); |
1934 | Value sliceResult = candidateSliceOp.getResult(); |
1935 | |
1936 | // 3. Fetch the corresponding output. |
1937 | OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); |
1938 | unsigned resultNumber = yieldOpOperand.getOperandNumber(); |
1939 | |
1940 | scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation()); |
1941 | |
1942 | return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); |
1943 | } |
1944 | |
1945 | /// Fetch the first untiled consumer of a scf.forall's result which is yielded |
1946 | /// by a tensor.parallel_insert_slice. |
1947 | static FailureOr<OpOperand *> |
1948 | getUntiledConsumerFromSlice(RewriterBase &rewriter, |
1949 | tensor::ParallelInsertSliceOp candidateSliceOp, |
1950 | MutableArrayRef<LoopLikeOpInterface> loops) { |
1951 | assert(!loops.empty() && "unexpected loops to be empty"); |
1952 | // 1. Check that the surrounding loop is a single scf.forall loop. |
1953 | if (loops.size() != 1) { |
1954 | return rewriter.notifyMatchFailure( |
1955 | candidateSliceOp, "expected single surrounding scf.forall"); |
1956 | } |
1957 | auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation()); |
1958 | if (!forallOp) { |
1959 | return rewriter.notifyMatchFailure( |
1960 | candidateSliceOp, "expected single surrounding scf.forall"); |
1961 | } |
1962 | |
1963 | // 2. Fetch the corresponding output |
1964 | Value sliceDest = candidateSliceOp.getDest(); |
1965 | auto iterArg = dyn_cast<BlockArgument>(Val&: sliceDest); |
1966 | if (!iterArg) |
1967 | return failure(); |
1968 | if (iterArg.getOwner()->getParentOp() != forallOp) |
1969 | return failure(); |
1970 | |
1971 | unsigned resultNumber = |
1972 | forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) |
1973 | .getResultNumber(); |
1974 | |
1975 | return getConsumerFromLoopUses(rewriter, forallOp, resultNumber); |
1976 | } |
1977 | |
1978 | /// A utility to fetch an untiled consumer of |
1979 | /// tensor.insert_slice/tensor.parallel_insert_slice. |
1980 | static FailureOr<OpOperand *> |
1981 | getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, |
1982 | MutableArrayRef<LoopLikeOpInterface> loops) { |
1983 | assert(!loops.empty() && "unexpected empty loops"); |
1984 | if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) { |
1985 | return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); |
1986 | } else if (auto parallelInsertSlice = |
1987 | dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) { |
1988 | return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops); |
1989 | } else { |
1990 | return failure(); |
1991 | } |
1992 | } |
1993 | |
1994 | /// Implementation of fusing consumer of a single slice by computing the |
1995 | /// slice of the consumer in-place for scf loop. |
1996 | FailureOr<scf::SCFFuseConsumerOfSliceResult> |
1997 | mlir::scf::tileAndFuseConsumerOfSlice( |
1998 | RewriterBase &rewriter, Operation *candidateSliceOp, |
1999 | MutableArrayRef<LoopLikeOpInterface> loops) { |
2000 | // Return if `loops` is empty, return an error for now. Caller is expected |
2001 | // to handle this case. |
2002 | if (loops.empty()) { |
2003 | return candidateSliceOp->emitOpError( |
2004 | message: "cannot call tile and fuse consumer with an empty loop nest"); |
2005 | } |
2006 | if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>( |
2007 | candidateSliceOp)) |
2008 | return failure(); |
2009 | |
2010 | // 1. Get the consumer of scf.for for the result yielded by |
2011 | // tensor.insert_slice/parallel_insert_slice. |
2012 | FailureOr<OpOperand *> maybeConsumerOpOperand = |
2013 | getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops); |
2014 | if (failed(Result: maybeConsumerOpOperand)) { |
2015 | return rewriter.notifyMatchFailure(arg&: candidateSliceOp, |
2016 | msg: "could not fetch consumer to fuse"); |
2017 | } |
2018 | OpOperand *consumerOpOperand = *maybeConsumerOpOperand; |
2019 | Operation *consumerOp = consumerOpOperand->getOwner(); |
2020 | unsigned operandNumber = consumerOpOperand->getOperandNumber(); |
2021 | unsigned resultNumber = 0; |
2022 | if (auto producerResult = dyn_cast<OpResult>(Val: consumerOpOperand->get())) { |
2023 | resultNumber = producerResult.getResultNumber(); |
2024 | } else { |
2025 | return rewriter.notifyMatchFailure( |
2026 | arg&: consumerOp, msg: "consumer op's operand doesn't seem to be an OpResult"); |
2027 | } |
2028 | |
2029 | LoopLikeOpInterface outerMostLoop = loops.front(); |
2030 | LoopLikeOpInterface innerMostLoop = loops.back(); |
2031 | |
2032 | // Check assumption for loop with `reorderOperations` disabled. |
2033 | if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { |
2034 | return rewriter.notifyMatchFailure( |
2035 | outerMostLoop, "the first user of loop should not dominate any define " |
2036 | "of consumer operand(s)"); |
2037 | } |
2038 | |
2039 | OpBuilder::InsertionGuard g(rewriter); |
2040 | |
2041 | // 2. Check consumer is not using scf loop's output as init. |
2042 | auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); |
2043 | if (!dstOp) |
2044 | return rewriter.notifyMatchFailure(arg&: consumerOp, |
2045 | msg: "consumer op is not DPS operation"); |
2046 | SmallVector<Value> dpsInits = |
2047 | llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); |
2048 | if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) { |
2049 | return rewriter.notifyMatchFailure( |
2050 | arg&: consumerOp, |
2051 | msg: "consumer op taking the result of scf.for as init is not supported"); |
2052 | } |
2053 | SmallVector<Value> newInits = dpsInits; |
2054 | |
2055 | Location loc = outerMostLoop->getLoc(); |
2056 | |
2057 | // 3. Move the whole loop structure right before firstUserOfLoop, the |
2058 | // dominance should be already ensured by `checkAssumptionForLoop`. |
2059 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop); |
2060 | if (failed(Result: firstUserOfLoop)) { |
2061 | return rewriter.notifyMatchFailure( |
2062 | outerMostLoop, "could not find the first user of outer most loop"); |
2063 | } |
2064 | rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop); |
2065 | |
2066 | // 4. Set insertion point before terminator op of the loop and create a new |
2067 | // tensor.insert_slice. In the scf.for case this is a clone of the |
2068 | // candidateSliceOp whereas in the scf.forall case this is created from the |
2069 | // operands of tensor.parallel_insert_slice. |
2070 | tensor::InsertSliceOp clonedInsertSliceOp; |
2071 | if (auto sliceOp = |
2072 | dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) { |
2073 | auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation()); |
2074 | rewriter.setInsertionPoint(newForallOp.getTerminator()); |
2075 | clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>( |
2076 | loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), |
2077 | sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); |
2078 | } else { |
2079 | rewriter.setInsertionPoint(candidateSliceOp); |
2080 | clonedInsertSliceOp = |
2081 | cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)); |
2082 | } |
2083 | |
2084 | // 5.a. Clone consumer op. |
2085 | auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(op&: *consumerOp)); |
2086 | |
2087 | // 5.b. Replace all uses of the loop result with the result of the cloned |
2088 | // tensor.insert_slice. |
2089 | OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber); |
2090 | rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { |
2091 | operandToReplace.set(clonedInsertSliceOp.getResult()); |
2092 | }); |
2093 | |
2094 | // 6. Perform tiling of the cloned consumer and replace the operand at |
2095 | // `operandNumber` with the source of the cloned tensor.insert_slice op. |
2096 | auto ossSliceOp = |
2097 | cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation()); |
2098 | FailureOr<TilingResult> tileAndFuseResult = |
2099 | tensor::replaceInsertSliceWithTiledConsumer( |
2100 | builder&: rewriter, sliceOp: ossSliceOp, consumerOp&: clonedConsumerOp->getOpOperand(operandNumber)); |
2101 | if (failed(Result: tileAndFuseResult)) { |
2102 | return failure(); |
2103 | } |
2104 | auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]); |
2105 | rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber), |
2106 | clonedInsertSliceOp.getSource()); |
2107 | |
2108 | // 7. Reconstruct [nested] loop with new inits. |
2109 | YieldTiledValuesFn newYieldValuesFn = |
2110 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
2111 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
2112 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
2113 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { |
2114 | OpBuilder::InsertionGuard g(innerRewriter); |
2115 | // 8. Set inner insertPoint right before tiled consumer op. |
2116 | innerRewriter.setInsertionPoint(tiledConsumerOp); |
2117 | |
2118 | SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets(); |
2119 | SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes(); |
2120 | SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides(); |
2121 | |
2122 | // 9. Check all insert stride is 1. |
2123 | if (!llvm::all_of(Range&: strides, P: isOneInteger)) { |
2124 | return rewriter.notifyMatchFailure( |
2125 | arg&: candidateSliceOp, msg: "containingOp's result yield with stride"); |
2126 | } |
2127 | |
2128 | // 10. Try to get iter domain position from input position. Use |
2129 | // clonedConsumerOp instead of tiledConsumerOp, because the iteration |
2130 | // domain may require index computation based on the result size. The |
2131 | // sizes and offsets should be the same either way, but using |
2132 | // tiledConsumerOp could lead to some chained unnecessary extra index |
2133 | // computation. |
2134 | SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes; |
2135 | if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( |
2136 | rewriter, operandNumber, offsets, sizes, iterDomainOffsets, |
2137 | iterDomainSizes))) { |
2138 | return rewriter.notifyMatchFailure( |
2139 | clonedConsumerOp, |
2140 | "can't get iter domain position from input position"); |
2141 | } |
2142 | |
2143 | // 11. Try to fetch the offset and size for all results of the cloned |
2144 | // consumer. This would then be used to form the corresponding |
2145 | // tensor.insert_slice/parallel_insert_slice later. |
2146 | unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults(); |
2147 | SmallVector<SmallVector<OpFoldResult>> resultOffsets( |
2148 | totalNumResultsOfConsumer); |
2149 | SmallVector<SmallVector<OpFoldResult>> resultSizes( |
2150 | totalNumResultsOfConsumer); |
2151 | for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) { |
2152 | if (failed(tiledConsumerOp.getResultTilePosition( |
2153 | rewriter, idx, iterDomainOffsets, iterDomainSizes, |
2154 | resultOffsets[idx], resultSizes[idx]))) { |
2155 | return rewriter.notifyMatchFailure( |
2156 | tiledConsumerOp, |
2157 | "can't get result domain position from iter domain position"); |
2158 | } |
2159 | } |
2160 | |
2161 | // 12. Create `extract_slice` for `iter_args` for DPS operation if |
2162 | // necessary. |
2163 | if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>( |
2164 | tiledConsumerOp.getOperation())) { |
2165 | rewriter.setInsertionPoint(tiledDestStyleOp); |
2166 | for (const auto &&[index, newRegionArg] : |
2167 | llvm::enumerate(First&: newRegionIterArgs)) { |
2168 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
2169 | loc, newRegionArg, resultOffsets[index], resultSizes[index], |
2170 | SmallVector<OpFoldResult>(resultOffsets[index].size(), |
2171 | rewriter.getIndexAttr(1))); |
2172 | // Make a copy of index to avoid a capturing structured binding, which |
2173 | // is a C++20 extension. |
2174 | auto dstNumber = index; |
2175 | rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { |
2176 | tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice); |
2177 | }); |
2178 | } |
2179 | } |
2180 | |
2181 | // 13. Prepare tiled offset and sizes for later `insert_slice` creation by |
2182 | // caller. |
2183 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
2184 | rewriter.setInsertionPoint(block->getTerminator()); |
2185 | for (const auto &&[index, result] : |
2186 | llvm::enumerate(tiledConsumerOp->getResults())) { |
2187 | tiledResult.push_back(result); |
2188 | tiledOffset.emplace_back(resultOffsets[index]); |
2189 | tiledSizes.emplace_back(resultSizes[index]); |
2190 | } |
2191 | return success(); |
2192 | }; |
2193 | // 14. Add new inits to [nested] loops. |
2194 | if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits, |
2195 | newYieldValuesFn))) { |
2196 | return rewriter.notifyMatchFailure(tiledConsumerOp, |
2197 | "unable to add new inits to nest loop"); |
2198 | } |
2199 | |
2200 | // 15. Replace the result of scf loop and consumer op with new loop's |
2201 | // results. |
2202 | |
2203 | for (auto &&[oldResult, newResult] : |
2204 | llvm::zip(consumerOp->getResults(), |
2205 | loops.front()->getResults().take_back(newInits.size()))) { |
2206 | rewriter.replaceAllUsesWith(oldResult, newResult); |
2207 | } |
2208 | |
2209 | // 16. Need to erase the old scf loop and the cloned consumer op. |
2210 | rewriter.eraseOp(op: clonedConsumerOp); |
2211 | |
2212 | return scf::SCFFuseConsumerOfSliceResult{ |
2213 | .origConsumerOperand: consumerOpOperand, |
2214 | .tiledAndFusedConsumerOperand: &(tileAndFuseResult->tiledOps[0]->getOpOperand(idx: operandNumber)), |
2215 | .tiledOps: tileAndFuseResult->tiledOps}; |
2216 | } |
2217 | |
2218 | //===----------------------------------------------------------------------===// |
2219 | // lowerToLoopsUsingSCFForOp implementation. |
2220 | //===----------------------------------------------------------------------===// |
2221 | |
2222 | FailureOr<SmallVector<scf::ForOp>> |
2223 | mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, |
2224 | TilingInterface op) { |
2225 | // TODO: Handle cases where the op has results if needed. |
2226 | if (op->getNumResults() > 0) { |
2227 | return rewriter.notifyMatchFailure( |
2228 | op, "unable to lower to loops operations with return values"); |
2229 | } |
2230 | |
2231 | SmallVector<Range> domain = op.getIterationDomain(rewriter); |
2232 | SmallVector<Value> ivs; |
2233 | SmallVector<scf::ForOp> loops; |
2234 | Location loc = op.getLoc(); |
2235 | for (auto loopRange : domain) { |
2236 | Value offsetVal = |
2237 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); |
2238 | Value sizeVal = |
2239 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); |
2240 | Value strideVal = |
2241 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); |
2242 | auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, |
2243 | strideVal, ValueRange{}); |
2244 | loops.push_back(loop); |
2245 | ivs.push_back(loop.getInductionVar()); |
2246 | rewriter.setInsertionPoint(loop.getBody()->getTerminator()); |
2247 | } |
2248 | if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { |
2249 | return failure(); |
2250 | } |
2251 | return loops; |
2252 | } |
2253 |
Definitions
- setTileSizes
- setNumThreads
- fillInterchangeVector
- verifyTileSizeOptions
- getUserTileSizesAndNumThreads
- checkSafeToTileToForall
- tileDividesIterationDomain
- getBoundedTileSize
- canOmitTileOffsetInBoundsCheck
- getTileOffsetAndSizes
- getLoopBounds
- cloneOpAndUpdateDestinationArgs
- generateLoopNestUsingForOp
- generateLoopNestUsingForallOp
- generateLoopNest
- createInitialTensorsForTiling
- getTiledImplementation
- getResultTilePosition
- mergeTilingResults
- yieldTiledValuesAndReplaceLoop
- yieldTiledValuesAndReplaceLoop
- addInitOperandsToLoopNest
- tileUsingSCF
- tileReductionUsingScf
- getUntiledProducerFromSliceSource
- tileAndFuseProducerOfSlice
- yieldReplacementForFusedProducer
- SliceTrackingListener
- SliceTrackingListener
- SliceTrackingListener
- insertAndApplyPatterns
- notifyOperationInserted
- removeOp
- notifyOperationErased
- notifyOperationReplaced
- ReplacementListener
- ReplacementListener
- updateReplacementValues
- notifyOperationReplaced
- notifyOperationReplaced
- tileConsumerAndFuseProducersUsingSCF
- checkAssumptionForFusingConsumer
- getFirstUserOfLoop
- checkAssumptionForLoop
- getConsumerFromLoopUses
- isPerfectlyNestedForLoops
- getUntiledConsumerFromSlice
- getUntiledConsumerFromSlice
- getUntiledConsumerFromSlice
- tileAndFuseConsumerOfSlice
Learn to use CMake with our Intro Training
Find out more