1//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the tiling using TilingInterface.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Arith/Utils/Utils.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/SCF/Utils/Utils.h"
22#include "mlir/Dialect/Tensor/IR/Tensor.h"
23#include "mlir/Dialect/Utils/IndexingUtils.h"
24#include "mlir/IR/Dominance.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Interfaces/DestinationStyleOpInterface.h"
28#include "mlir/Interfaces/TilingInterface.h"
29#include "mlir/Rewrite/FrozenRewritePatternSet.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/ScopeExit.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/Debug.h"
34#include <optional>
35
36#define DEBUG_TYPE "tile-using-interface"
37
38using namespace mlir;
39
40scf::SCFTilingOptions &
41scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
42 assert(!tileSizeComputationFunction && "tile sizes already set");
43 auto tileSizes = llvm::to_vector(Range&: ts);
44 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45 return tileSizes;
46 };
47 return *this;
48}
49
50scf::SCFTilingOptions &
51scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
52 assert(!numThreadsComputationFunction && "num tiles already set");
53 auto numThreads = llvm::to_vector(Range&: nt);
54 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55 return numThreads;
56 };
57 return *this;
58}
59
60/// Helper method to adjust the interchange vector to match the iteration
61/// domain.
62static SmallVector<int64_t>
63fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
64 size_t iterationDomainSize) {
65 SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector);
66 if (filledVector.size() < iterationDomainSize) {
67 auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize);
68 filledVector.append(in_start: range.begin(), in_end: range.end());
69 }
70 if (filledVector.size() > iterationDomainSize)
71 filledVector.resize(N: iterationDomainSize);
72 return filledVector;
73}
74
75//===----------------------------------------------------------------------===//
76// tileUsingSCF implementation.
77//===----------------------------------------------------------------------===//
78
79/// Verify the tile size options are set in a consistent manner.
80static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
81 const scf::SCFTilingOptions &options) {
82 // Specifying number of threads is only supported on `scf.forall` op.
83 if (options.numThreadsComputationFunction &&
84 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
85 return rewriter.notifyMatchFailure(
86 arg&: loc, msg: "number of threads can only by specified when loop type is "
87 "set to use `scf.forall`");
88 }
89
90 // If specified, check that the interchange vector is a permutation.
91 if (!options.interchangeVector.empty()) {
92 if (!isPermutationVector(interchange: options.interchangeVector)) {
93 return rewriter.notifyMatchFailure(
94 arg&: loc, msg: "invalid interchange vector, not a permutation of the entire "
95 "iteration space");
96 }
97 }
98 return success();
99}
100
101/// Method to instantiate the tile sizes and/or number of threads specified
102/// by the user.
103static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
104getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
105 ArrayRef<Range> iterationDomain,
106 const scf::SCFTilingOptions &options) {
107 OpFoldResult zero = rewriter.getIndexAttr(value: 0);
108 SmallVector<OpFoldResult> tileSizes, numThreads;
109 size_t numLoops = iterationDomain.size();
110
111 // Check whether the number of tiles to use is specified.
112 if (options.numThreadsComputationFunction) {
113 numThreads = options.numThreadsComputationFunction(rewriter, op);
114 numThreads.resize(N: numLoops, NV: zero);
115
116 // If the number of tiles is also specified, use that.
117 if (options.tileSizeComputationFunction) {
118 tileSizes = options.tileSizeComputationFunction(rewriter, op);
119 tileSizes.resize(N: numLoops, NV: zero);
120 return {tileSizes, numThreads};
121 }
122
123 // Compute the tile sizes from the iteration domain and number
124 // of tiles as follows
125 // - niters = ceilDiv(ub - lb, step)
126 // - tileSize = ceilDiv(niters, numThreads)
127 AffineExpr s0, s1, s2;
128 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2);
129 // TODO: The step here is assumed to be 1.
130 AffineExpr numItersExpr = (s1 - s0);
131 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(other: s2);
132 tileSizes.resize(N: numLoops, NV: zero);
133 for (auto [index, range, nt] :
134 llvm::enumerate(First&: iterationDomain, Rest&: numThreads)) {
135 if (isZeroInteger(v: nt))
136 continue;
137
138 tileSizes[index] = affine::makeComposedFoldedAffineApply(
139 b&: rewriter, loc: op.getLoc(), expr: tileSizeExpr, operands: {range.offset, range.size, nt});
140 }
141 tileSizes.resize(N: numLoops, NV: zero);
142 return {tileSizes, numThreads};
143 }
144
145 // Enforce the convention that "tiling by zero"
146 // skips tiling a particular dimension. This convention is significantly
147 // simpler to handle instead of adjusting affine maps to account for missing
148 // dimensions.
149 assert(options.tileSizeComputationFunction &&
150 "expected tile sizes to be specified");
151 tileSizes = options.tileSizeComputationFunction(rewriter, op);
152 tileSizes.resize(N: numLoops, NV: zero);
153
154 return {tileSizes, numThreads};
155}
156
157/// Checks if any of the tiled loops are not parallel.
static LogicalResult checkTileSizes(TilingInterface op,
                                    scf::SCFTilingOptions::LoopType loopType,
                                    ReductionTilingStrategy reductionStrategy,
                                    ArrayRef<OpFoldResult> tileSizes,
                                    ArrayRef<OpFoldResult> numThreads) {
  auto iterators = op.getLoopIteratorTypes();
  assert(iterators.size() == tileSizes.size() &&
         "expected as many tile size values as number of loops");
  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
         "when specified, expected number of threads to use for each loop");

  bool isParallelTiling = false;
  for (auto [index, iterator, tileSize] :
       llvm::enumerate(iterators, tileSizes)) {
    // A tile size of 0 means this loop is left untiled; only non-zero tile
    // sizes on parallel iterators count as "tiling a parallel dimension".
    if (!isConstantIntValue(tileSize, 0)) {
      isParallelTiling |= iterator == utils::IteratorType::parallel;
    }

    // Thread-safety diagnostics only apply to `scf.forall` with a full
    // reduction: distributing a non-parallel dimension races on the result.
    if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
        reductionStrategy == ReductionTilingStrategy::FullReduction) {
      // If num threads is specified, check that it is greater than one only
      // for parallel dimensions.
      if (!numThreads.empty()) {
        if (std::optional<int64_t> constNumThreads =
                getConstantIntValue(numThreads[index])) {
          if (constNumThreads.value() > 1 &&
              iterator != utils::IteratorType::parallel) {
            op.emitWarning() << "tiling is not thread safe at axis #" << index;
          }
        }
        continue;
      }

      // Otherwise the same check is driven by the tile size: a positive
      // static tile size on a non-parallel iterator is unsafe to distribute.
      if (std::optional<int64_t> constTileSize =
              getConstantIntValue(tileSize)) {
        if (constTileSize.value() > 0 &&
            iterator != utils::IteratorType::parallel) {
          op.emitWarning() << "tiling is not thread safe at axis #" << index;
        }
      }
    }
  }

  // Partial-reduction strategies split reduction loops; tiling parallel
  // dimensions at the same time is not implemented.
  if (reductionStrategy != ReductionTilingStrategy::FullReduction) {
    if (isParallelTiling) {
      return op->emitOpError("tiling parallel dimensions is not supported with "
                             "partial reduction tiling strategies");
    }
  }
  return success();
}
209
210/// Get the reduction dims that are tiled. This accounts for reduction dims
211/// that are specified as tiled, but the tile size is 0.
212static SetVector<unsigned>
213getSanitizedReductionDims(ArrayRef<OpFoldResult> tileSizes,
214 const scf::SCFTilingOptions &options) {
215 SetVector<unsigned> reductionDims;
216 for (auto dim : options.reductionDims) {
217 if (isConstantIntValue(ofr: tileSizes[dim], value: 0))
218 continue;
219 reductionDims.insert(X: dim);
220 }
221 return reductionDims;
222}
223
224/// Check if `stride` evenly divides the trip count `size - offset`.
225static bool tileDividesIterationDomain(Range loopRange) {
226 std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset);
227 if (!offsetAsInt)
228 return false;
229 std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size);
230 if (!sizeAsInt)
231 return false;
232 std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride);
233 if (!strideAsInt)
234 return false;
235 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
236}
237
238/// Returns the bounded tile size given the current `offset`, `loopRange` and
239/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                       Range loopRange, OpFoldResult offset,
                                       OpFoldResult tileSize) {
  std::optional<int64_t> ts = getConstantIntValue(tileSize);
  // A static tile size of 1 can never overrun the iteration domain, so no
  // clamping is required.
  if (ts && ts.value() == 1)
    return tileSize;

  // If the tile size statically divides the trip count, every tile is full
  // and the size is exact as-is.
  if (tileDividesIterationDomain(
          Range{loopRange.offset, loopRange.size, tileSize}))
    return tileSize;

  // The tile size to use (to avoid out of bounds access) is minimum of
  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the
  // tiled loop. Encoded as affine_min(s0 - d0, s1) over (iv)[ub, tileSize].
  AffineExpr s0, s1, d0;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0, s1);
  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
  return affine::makeComposedFoldedAffineMin(
      b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
}
262
263/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
264/// than `iterationSize`.
265static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
266 OpFoldResult numThreads,
267 OpFoldResult iterationSize) {
268 std::optional<int64_t> tileSizeConst = getConstantIntValue(ofr: tileSize);
269 std::optional<int64_t> numThreadsConst = getConstantIntValue(ofr: numThreads);
270 std::optional<int64_t> iterSizeConst = getConstantIntValue(ofr: iterationSize);
271 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
272 return false;
273 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
274}
275
276/// Compute the `OpFoldResult`s that represents the multi-dimensional
277/// `offset`s and `size`s of the tile of the iteration space that the
278/// innermost loop body of the generated tiled loops corresponds to.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc,
                      ReductionTilingStrategy strategy, ValueRange ivs,
                      ArrayRef<Range> iterationDomain,
                      ArrayRef<OpFoldResult> tileSizes,
                      ArrayRef<OpFoldResult> numThreads,
                      const llvm::SetVector<unsigned> &reductionDims) {
  SmallVector<OpFoldResult> offsets, sizes;
  // `ivs` only holds values for the materialized (tiled) loops; this counter
  // tracks which iv belongs to the current dimension.
  int materializedLoopNum = 0;

  if (!numThreads.empty()) {
    // Thread-count-based tiling: the iv is a thread id, so the tile offset
    // is `lb + iv * tileSize` (offsetExpr over (lb, iv)[tileSize]).
    AffineExpr d0, d1, s0, s1;
    AffineExpr offsetExpr, residualTileSizeExpr;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0, s1);
    offsetExpr = d0 + d1 * s0;
    // Amount by which `numThreads * tileSize` overshoots the loop size;
    // zero means the tiles evenly cover the domain.
    residualTileSizeExpr = s1 - (d0 + d1 * s0);

    for (auto [index, nt, tileSize, loopRange] :
         llvm::enumerate(numThreads, tileSizes, iterationDomain)) {

      // Non-tiled cases, set the offset and size to the
      // `loopRange.offset/size`.
      if (isZeroInteger(nt)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }

      Value iv = ivs[materializedLoopNum++];
      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, offsetExpr,
          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
          rewriter, loc, residualTileSizeExpr,
          {loopRange.offset, nt, tileSize, loopRange.size});

      OpFoldResult size = tileSize;
      if (!isZeroInteger(residualTileSize)) {
        // Tiles do not evenly divide the domain: clamp the (last) tile to
        // `ub - offset`.
        OpFoldResult sizeMinusOffsetPerThread =
            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                  {offset, loopRange.size});
        size = affine::makeComposedFoldedAffineMin(
            rewriter, loc,
            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
            {sizeMinusOffsetPerThread, tileSize});
      }

      // Consider the case where the original loop was `[0, 100)`.
      // If number of threads are `7`, the tile size would be computed as
      // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
      // - `offset = 0 + 6 * 15 = 105`
      // - `tileSize = min(15, 100 - 105) = -5`
      // To avoid negative tile sizes, we need to do a further
      // `nonNegativeTileSize = affine.max(0, tileSize)`.
      // This `max` can be avoided if
      // `offset + tileSize * (numThreads - 1) < (ub - lb)`
      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
        AffineMap maxMap =
            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
        size = affine::makeComposedFoldedAffineMax(
            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
      }

      offsets.push_back(offset);
      sizes.push_back(size);
    }
    return {offsets, sizes};
  } else {
    // Tile-size-based tiling: the iv already is the tile offset; only the
    // size needs clamping at the domain boundary.
    for (auto [tileSize, loopRange] :
         llvm::zip_equal(tileSizes, iterationDomain)) {

      // Non-tiled cases, set the offset and size to the
      // `loopRange.offset/size`.
      if (isZeroInteger(tileSize)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }

      Value iv = ivs[materializedLoopNum++];
      OpFoldResult offset = getAsOpFoldResult(iv);
      offsets.push_back(offset);
      OpFoldResult size =
          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
      sizes.push_back(size);
    }
    return {offsets, sizes};
  }
}
369
370/// Function to return the bounds of the loops to be generated.
371static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
372 SmallVector<OpFoldResult>>
373getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
374 ArrayRef<OpFoldResult> tileSizes) {
375 SmallVector<OpFoldResult> lbs, ubs, steps;
376 for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) {
377 // No loop if the tile size is 0.
378 if (isZeroInteger(v: tileSize))
379 continue;
380 lbs.push_back(Elt: loopRange.offset);
381 ubs.push_back(Elt: loopRange.size);
382 steps.push_back(Elt: tileSize);
383 }
384 return {lbs, ubs, steps};
385}
386
387/// A function that allows returning additional yielded values during
388/// `yieldTiledValuesAndReplace`.
389/// - `ivs` induction variable for the loop.
390/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
391/// - `tiledValues` the tiled values to return. Must be of same size as
392/// `newbbArgs`, each element of this array is inserted into the corresponding
393/// element in `newbbArgs`.
394/// - `resultOffsets` is of the same size as `tiledValues` and represents
395/// the offsets to use when inserting corresponding element from `tiledValues`
396/// into the element from `newBbArgs`.
397/// - `resultSizes` is of the same size as `tiledValues` and represents
398/// the size of the corresponding element from `tiledValues` inserted into
399/// the element from `newBbArgs`.
400/// In case the method needs to return `failure()` the method is expected
401/// to clean up any inserted operations.
// Invoked exactly once per tiling, at the innermost point of the generated
// loop nest (or directly when no loops are generated at all).
using YieldTiledValuesFn = std::function<LogicalResult(
    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
    SmallVector<Value> &tiledValues,
    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
407
408/// Clones the operation and updates the destination if the operation
409/// implements the `DestinationStyleOpInterface`.
410static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
411 Operation *op,
412 ValueRange newDestArgs) {
413 Operation *clonedOp = rewriter.clone(op&: *op);
414 if (newDestArgs.empty())
415 return clonedOp;
416 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(Val: clonedOp))
417 destinationStyleOp.getDpsInitsMutable().assign(values: newDestArgs);
418 return clonedOp;
419}
420
/// Generate the tile-loop nest using `scf.for` operations.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration
///   space.
/// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
/// - `destinationTensors` are the init values to use for the outermost loop.
/// - `yieldTiledValuesFn` is called to generate the body of the innermost
///   loop.
/// - `loops` is an in-out parameter into which the generated loops are
///   populated.
static LogicalResult generateLoopNestUsingForOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
    YieldTiledValuesFn yieldTiledValuesFn,
    SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  // Materialize loop bounds (zero tile sizes are dropped by getLoopBounds).
  SmallVector<OpFoldResult> lbs, ubs, steps;
  std::tie(lbs, ubs, steps) =
      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
  SmallVector<Value> lbVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
  SmallVector<Value> ubVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
  SmallVector<Value> stepVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, steps);

  // Build the nest outermost-to-innermost, threading the destination tensors
  // through each loop's iter_args; the insertion point ends up inside the
  // innermost body.
  SmallVector<Value> ivs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
    auto loop =
        rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                    [](OpBuilder &bodyBuilder, Location bodyLoc,
                                       Value iv, ValueRange /*iterArgs*/) {});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPointToEnd(loop.getBody());
    destinationTensors = loop.getRegionIterArgs();
  }

  // Generate the innermost loop body via the callback.
  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
                                tiledResults, resultOffsets, resultSizes))) {
    return rewriter.notifyMatchFailure(
        loc, "failed to generate inner tile loop body");
  }
  // No loops materialized: nothing to yield.
  if (loops.empty())
    return success();

  assert(tiledResults.size() == destinationTensors.size() &&
         "Number of results of body should be equal to number of iter args");

  // Yield all the results of the tiled operation: insert each tiled value
  // into its slice of the corresponding iter_arg and yield the result.
  SmallVector<Value> yieldedValues;
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
    yieldedValues.push_back(insertSlice);
  }
  rewriter.create<scf::YieldOp>(loc, yieldedValues);

  // Add the scf.yield operations for all the outer loops: each outer loop
  // yields the results of the loop immediately nested within it.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
                       MutableArrayRef(loops).drop_front())) {
    rewriter.setInsertionPointToEnd(
        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
    rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
  }
  return success();
}
499
500/// Generate the tile-loop nest using `scf.forall` operation.
501/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
502/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
503/// - `destinationTensors` are the init values to use for the outer most loop.
504/// - `mappingVector` is the mapping attributes to use for loop construction.
505/// Can be empty.
506/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
507/// most
508/// loop.
509/// - `loops` is an in-out parameter into which the generated loops are
510/// populated.
static LogicalResult generateLoopNestUsingForallOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
    ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  std::optional<ArrayAttr> mappingAttr;
  if (!mappingVector.empty())
    mappingAttr = rewriter.getArrayAttr(mappingVector);

  scf::ForallOp forallOp;
  bool useNumThreads = !numThreads.empty();

  if (useNumThreads) {
    // Prune the zero numthreads: those dimensions are not distributed and
    // get no forall induction variable.
    SmallVector<OpFoldResult> nonZeroNumThreads;
    for (auto nt : numThreads) {
      if (isZeroInteger(nt))
        continue;
      nonZeroNumThreads.push_back(nt);
    }
    forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
                                              destinationTensors, mappingAttr);
  } else {
    // Tile-size driven form: build the forall from explicit lb/ub/step.
    SmallVector<OpFoldResult> lbs, ubs, steps;
    std::tie(lbs, ubs, steps) =
        getLoopBounds(rewriter, loc, loopRanges, tileSizes);
    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
                                              destinationTensors, mappingAttr);
  }
  loops.push_back(forallOp);

  // Generate the body right before the terminator, with the forall's shared
  // out args standing in for the destinations.
  rewriter.setInsertionPoint(forallOp.getTerminator());
  destinationTensors = forallOp.getRegionOutArgs();

  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
                         destinationTensors, tiledResults, resultOffsets,
                         resultSizes)))
    return rewriter.notifyMatchFailure(loc, "failed to generate loop body");

  // Write each tiled result into its destination slice inside the forall
  // terminator region.
  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));

    rewriter.create<tensor::ParallelInsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
  }
  return success();
}
570
/// Generate the tile-loop nest using the loop construct specified in
/// `options`.
/// - `options`: Tiling options specified.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration
///   space.
/// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
/// - `destinationTensors` are the init values to use for the outermost loop.
/// - `yieldTiledValuesFn` is called to generate the body of the innermost
///   loop.
/// - `loops` is an in-out parameter into which the generated loops are
///   populated.
581static LogicalResult generateLoopNest(
582 RewriterBase &rewriter, Location loc,
583 scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges,
584 ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
585 ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
586 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
587 // If the tile sizes are all zero, no loops are generated. Just call the
588 // callback function to handle untiled case.
589 if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) {
590 SmallVector<Value> tiledResults;
591 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
592 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
593 tiledResults, resultOffsets, resultSizes);
594 }
595 if (loopType == scf::SCFTilingOptions::LoopType::ForOp) {
596 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
597 destinationTensors, yieldTiledValuesFn: tiledBodyFn, loops);
598 }
599 if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
600 return generateLoopNestUsingForallOp(
601 rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
602 destinationTensors, tiledBodyFn, loops);
603 }
604 return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type");
605}
606
/// Create the init values for the tiled loops. For full reductions these are
/// simply the op's destinations; for partial reductions the op builds an
/// expanded accumulator whose per-dimension sizes are computed here.
static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
    RewriterBase &rewriter, TilingInterface op,
    ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
    const SetVector<unsigned> &reductionDims) {
  SmallVector<Value> initTensors;
  Location loc = op->getLoc();
  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
    if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
      return failure();
    return initTensors;
  }

  // Partial reductions need the op-specific interface to build the
  // accumulator tensor.
  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return op->emitOpError(
        "PartialReductionOuterReduction tiling strategy is only supported for "
        "operations implementing PartialReductionOpInterface");
  }
  // Per-dimension extents of the partial-result tensor.
  SmallVector<OpFoldResult> sizes(iterationDomain.size());
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  // Normalized trip count: ceilDiv(size - offset, stride).
  AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
  AffineExpr divExpr = s0.ceilDiv(s1);
  for (auto [index, domain, tileSize] :
       llvm::enumerate(iterationDomain, tileSizes)) {
    if (!numThreads.empty()) {
      // Untiled case.
      if (isConstantIntValue(numThreads[index], 0)) {
        sizes[index] = affine::makeComposedFoldedAffineApply(
            rewriter, op.getLoc(), sizeExpr,
            {domain.size, domain.offset, domain.stride});
        continue;
      }
      // Thread-based distribution: one partial entry per thread.
      sizes[index] = numThreads[index];
      continue;
    }

    // Non reduction dimensions/non-tiled dimensions: keep the full
    // (normalized) extent.
    if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) {
      sizes[index] = affine::makeComposedFoldedAffineApply(
          rewriter, op.getLoc(), sizeExpr,
          {domain.size, domain.offset, domain.stride});
      continue;
    }

    if (reductionStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      // Outer-reduction: the partial result spans one tile of the reduction
      // dimension.
      sizes[index] = tileSize;
      continue;
    }

    assert(reductionStrategy ==
           ReductionTilingStrategy::PartialReductionOuterParallel);
    // Outer-parallel: one partial entry per tile, i.e.
    // ceilDiv(tripCount, tileSize).
    OpFoldResult normalizedRange = affine::makeComposedFoldedAffineApply(
        rewriter, op.getLoc(), sizeExpr,
        {domain.size, domain.offset, domain.stride});
    sizes[index] = affine::makeComposedFoldedAffineApply(
        rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
  }
  return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
                                                        reductionDims);
}
670
671/// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel`
672/// the `PartialReductionOpInterface` methods need the index of the parallel
673/// split reduction being executed.
static SmallVector<OpFoldResult>
getSplitReductionIvs(RewriterBase &rewriter, Location loc,
                     ReductionTilingStrategy reductionStrategy, ValueRange ivs,
                     ArrayRef<OpFoldResult> numThreads,
                     ArrayRef<OpFoldResult> tileSizes,
                     const SetVector<unsigned> &reductionDims) {
  // Default every entry to 0; only the outer-parallel strategy fills in
  // real split-reduction indices.
  SmallVector<OpFoldResult> splitReductionIvs;
  splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  AffineExpr divExpr = s0.floorDiv(s1);
  int ivIndex = 0;
  if (reductionStrategy ==
      ReductionTilingStrategy::PartialReductionOuterParallel) {
    for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
      // With thread-based distribution, the iv already is the thread index.
      if (!numThreads.empty()) {
        splitReductionIvs[index] = ivs[ivIndex++];
        continue;
      }
      // With tile-size-based loops the iv is an offset into the iteration
      // space; floor-divide by the tile size to recover the tile index.
      splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
          rewriter, loc, divExpr,
          ArrayRef<OpFoldResult>{ivs[ivIndex++], tileSizes[reductionDim]});
    }
  }
  return splitReductionIvs;
}
700
/// Generate the tiled computation for the current tile: the op's plain tiled
/// implementation for full reductions, or the partial-reduction body (which
/// accumulates into `regionIterArg`) for partial strategies.
static FailureOr<TilingResult>
getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                       ReductionTilingStrategy reductionStrategy,
                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                       ArrayRef<OpFoldResult> sizes, ValueRange ivs,
                       ArrayRef<OpFoldResult> numThreads,
                       ArrayRef<OpFoldResult> tileSizes,
                       const SetVector<unsigned> &reductionDims) {
  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
    return op.getTiledImplementation(rewriter, offsets, sizes);
  }

  // Partial reduction strategies require the op to implement
  // PartialReductionOpInterface.
  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return rewriter.notifyMatchFailure(
        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
  }

  SmallVector<OpFoldResult> splitReductionIvs =
      getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
                           numThreads, tileSizes, reductionDims);
  return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
                                      regionIterArg, offsets, sizes,
                                      reductionDims, splitReductionIvs);
}
728
729static LogicalResult getResultTilePosition(
730 RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy,
731 int64_t index, Value tiledResult, TilingInterface op,
732 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
733 ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
734 ArrayRef<OpFoldResult> tileSizes, const SetVector<unsigned> &reductionDims,
735 SmallVector<OpFoldResult> &resultOffset,
736 SmallVector<OpFoldResult> &resultSize) {
737
738 if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
739 return op.getResultTilePosition(b&: rewriter, resultNumber: index, offsets, sizes,
740 resultOffsets&: resultOffset, resultSizes&: resultSize);
741 }
742 auto redOp = dyn_cast<PartialReductionOpInterface>(Val: op.getOperation());
743 if (!redOp) {
744 return rewriter.notifyMatchFailure(
745 arg&: op, msg: "PartialReductionOuterReduction tiling strategy is only supported"
746 "for operations implementing PartialReductionOpInterface");
747 }
748 SmallVector<OpFoldResult> splitReductionIvs =
749 getSplitReductionIvs(rewriter, loc: op.getLoc(), reductionStrategy, ivs,
750 numThreads, tileSizes, reductionDims);
751 return redOp.getPartialResultTilePosition(
752 b&: rewriter, resultNumber: index, tilingStrategy: reductionStrategy, offsets, sizes, reductionDims,
753 splitReductionIvs, resultOffsets&: resultOffset, resultSizes&: resultSize);
754}
755
/// Merge the per-tile partial reduction results (`partialResults`) produced
/// by the tiled loop nest into the final reduced value(s). Must only be
/// called for partial-reduction tiling strategies (asserted below), and
/// requires `op` to implement `PartialReductionOpInterface`, whose
/// `mergeReductions` performs the actual merge over `reductionDims`.
static FailureOr<MergeResult>
mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                   ReductionTilingStrategy reductionStrategy,
                   const SetVector<unsigned> &reductionDims,
                   ValueRange partialResults) {
  assert(reductionStrategy != ReductionTilingStrategy::FullReduction &&
         "expected merge to be called for only partial reduction cases");

  auto redOp = dyn_cast<PartialReductionOpInterface>(Val: op.getOperation());
  if (!redOp) {
    return rewriter.notifyMatchFailure(
        arg&: op, msg: "PartialReductionOuterReduction tiling strategy is only "
                "supported for operations "
                "implementing PartialReductionOpInterface");
  }
  return redOp.mergeReductions(b&: rewriter, loc: op.getLoc(), partialReduce: partialResults,
                               reductionDims);
}
774
775/// Append the specified additional `newInitOperands` operands to the
776/// loops existing `init` operands (or similar), and replace `loopOp` with
777/// the new loop that has the additional init operands. The loop body of
778/// this loop is moved over to the new loop. `yieldTiledValuesFn`
779/// is called to get the new tiled values returned, and the offset
780/// and sizes at which the tiled value is inserted into the
781/// new region iter_args that correspond to the newly added init operands.
782template <typename LoopType>
783FailureOr<LoopLikeOpInterface>
784yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
785 ValueRange newInitOperands,
786 YieldTiledValuesFn yieldTiledValuesFn) {
787 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
788}
789
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // Create a new `scf.for` whose inits are the original inits followed by
  // `newInitOperands`; its body is created empty and populated by moving the
  // old body over.
  auto inits = llvm::to_vector(Range: loopOp.getInitArgs());
  inits.append(in_start: newInitOperands.begin(), in_end: newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      location: loc, args: loopOp.getLowerBound(), args: loopOp.getUpperBound(), args: loopOp.getStep(),
      args&: inits, args: [](OpBuilder &, Location, Value, ValueRange) {});

  // Move the loop body to the new op. The old block arguments map onto the
  // leading arguments of the new block (the appended iter_args are extra).
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      source: loopBody, dest: newLoopBody,
      argValues: newLoopBody->getArguments().take_front(N: loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(Val: newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  // Invoke the callback to get the tiled values and the offsets/sizes at
  // which they are inserted into the new iter_args. On failure, roll back by
  // erasing the freshly created loop.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(N: newInitOperands.size());
  if (failed(Result: yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                 newRegionIterArgs, tiledValues, resultOffsets,
                                 resultSizes))) {
    rewriter.eraseOp(op: newLoop);
    return rewriter.notifyMatchFailure(arg&: loopOp, msg: "failed to get tiled values");
  }

  // Append one unit-stride `tensor.insert_slice` per tiled value into its
  // iter_arg, and yield the inserted tensors alongside the original yields.
  SmallVector<Value> newYieldValues = llvm::to_vector(Range: yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(t&: tiledValues, u&: newRegionIterArgs, args&: resultOffsets,
                       args&: resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(value: 1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        location: yieldOp->getLoc(), args&: tiledValue, args&: regionIterArg, args&: resultOffset, args&: resultSize,
        args&: resultStride);
    newYieldValues.push_back(Elt: insert);
  }

  // Replace the old loop; its results line up with the leading results of
  // the new loop.
  rewriter.replaceOpWithNewOp<scf::YieldOp>(op: yieldOp, args&: newYieldValues);
  rewriter.replaceOp(op: loopOp,
                     newValues: newLoop->getResults().take_front(n: loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(Val: newLoop.getOperation());
}
843
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // Create a new `scf.forall` whose shared outputs are the original outputs
  // followed by `newInitOperands`; the body is populated by moving the old
  // body over.
  auto inits = llvm::to_vector(Range: loopOp.getOutputs());
  inits.append(in_start: newInitOperands.begin(), in_end: newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      location: loc, args: loopOp.getMixedLowerBound(), args: loopOp.getMixedUpperBound(),
      args: loopOp.getMixedStep(), args&: inits, args: loopOp.getMapping(),
      args: [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op. The old
  // block arguments map onto the leading arguments of the new block.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      source: loopBody, dest: newLoopBody,
      argValues: newLoopBody->getArguments().take_front(N: loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(Val: newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  // Invoke the callback to get the tiled values and the offsets/sizes at
  // which they are inserted into the new iter_args. On failure, roll back by
  // erasing the freshly created loop.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(N: newInitOperands.size());
  if (failed(Result: yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                 regionIterArgs, tiledValues, resultOffsets,
                                 resultSizes))) {
    rewriter.eraseOp(op: newLoop);
    return rewriter.notifyMatchFailure(arg&: loopOp,
                                       msg: "failed to get yielded tiled values");
  }

  // Update the terminator.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  // Insert each tiled value into its iter_arg through a unit-stride
  // `tensor.parallel_insert_slice` inside the `scf.forall.in_parallel` block.
  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           t&: tiledValues, u&: regionIterArgs, args&: resultOffsets, args&: resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(value: 1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        location: terminator.getLoc(), args&: tiledValue, args&: iterArg, args&: resultOffset, args&: resultSize,
        args&: resultStride);
  }

  // Replace the old loop; its results line up with the leading results of
  // the new loop.
  rewriter.replaceOp(op: loopOp,
                     newValues: newLoop->getResults().take_front(n: loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(Val: newLoop.getOperation());
}
896
897/// Implementation of `yieldTiledValuesAndReplaceLoop` for
898/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
899/// supported loop type.
900FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
901 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
902 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
903 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
904 loopLikeOp.getOperation())
905 .Case<scf::ForOp, scf::ForallOp>(
906 caseFn: [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
907 return yieldTiledValuesAndReplaceLoop(
908 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
909 })
910 .Default(defaultFn: [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
911 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
912 });
913}
914
/// Method to add new init values to a loop nest. Updates `loops` in-place
/// with new loops that use the `newInitValues`. The outer-loops are updated
/// to yield the new result values of the inner loop. For the innermost loop,
/// the call back `getNewTiledYieldsFn` is invoked to get the additional
/// values to yield from the innermost loop.
static LogicalResult addInitOperandsToLoopNest(
    RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
    ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
  if (loops.empty())
    return success();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops.front());

  SmallVector<Value> ivs;
  // Rebuild every loop except the innermost one with the extra init values
  // threaded through, outermost first.
  for (auto &loop : loops.drop_back()) {
    rewriter.setInsertionPoint(loop);

    // if loops.size() > 1 we assume that scf.for is used for the loops.
    auto forLoop = cast<scf::ForOp>(Val: loop.getOperation());

    // Create a new loop with the new init values for this loop.
    SmallVector<Value> newInits = llvm::to_vector(Range: forLoop.getInitArgs());
    newInits.append(in_start: newInitValues.begin(), in_end: newInitValues.end());
    auto newLoop = rewriter.create<scf::ForOp>(
        location: forLoop.getLoc(), args: forLoop.getLowerBound(), args: forLoop.getUpperBound(),
        args: forLoop.getStep(), args&: newInits,
        args: [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});

    // Merge the body of the new loop with the body of the old loops.
    // The old induction variable and original iter_args map onto the leading
    // block arguments of the new loop.
    SmallVector<Value> sourceBlockArgs;
    sourceBlockArgs.push_back(Elt: newLoop.getInductionVar());
    auto newRegionIterArgs = newLoop.getRegionIterArgs();
    sourceBlockArgs.append(
        in_start: newRegionIterArgs.begin(),
        in_end: std::next(x: newRegionIterArgs.begin(), n: forLoop.getNumResults()));
    rewriter.mergeBlocks(source: forLoop.getBody(), dest: newLoop.getBody(), argValues: sourceBlockArgs);
    rewriter.replaceOp(
        op: forLoop, newValues: newLoop.getResults().take_front(n: forLoop.getNumResults()));
    loop = newLoop;
    ivs.push_back(Elt: newLoop.getInductionVar());
    // The next (inner) loop takes its new inits from the iter_args just
    // appended to this loop.
    newInitValues = newLoop.getRegionIterArgs().take_back(N: newInitValues.size());
  }

  // Update the loop body of the innermost loop to get new yield values.
  LoopLikeOpInterface innerMostLoop = loops.back();
  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
      yieldTiledValuesAndReplaceLoop(loopLikeOp: innerMostLoop, rewriter, newInitOperands: newInitValues,
                                     yieldTiledValuesFn: getNewTiledYieldsFn);

  if (failed(Result: newInnerMostLoop))
    return innerMostLoop.emitOpError(message: "failed to return additional yields");
  loops.back() = newInnerMostLoop.value();

  // Make all other loops except the innermost loops yield the values returned
  // by the inner loop.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(t: loops.drop_back(), u: loops.drop_front())) {
    // Again assume that all the outer loops are scf.for operations.
    auto outerForLoop = cast<scf::ForOp>(Val&: outerLoop);
    auto outerLoopYield =
        cast<scf::YieldOp>(Val: outerForLoop.getBody()->getTerminator());
    SmallVector<Value> newYields =
        llvm::to_vector(Range: outerLoopYield.getOperands());
    ValueRange additionalYields =
        innerLoop->getResults().take_back(n: newInitValues.size());
    newYields.append(in_start: additionalYields.begin(), in_end: additionalYields.end());
    rewriter.setInsertionPoint(outerLoopYield);
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op: outerLoopYield, args&: newYields);
  }
  return success();
}
986
987/// Implementation of tiling transformation of `op` that implements the
988/// `TilingInterface` using `scf.for` to iterate over the tiles.
989FailureOr<scf::SCFTilingResult>
990mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
991 const scf::SCFTilingOptions &options) {
992 if (failed(Result: verifyOptions(rewriter, loc: op.getLoc(), options))) {
993 return failure();
994 }
995
996 OpBuilder::InsertionGuard guard(rewriter);
997 rewriter.setInsertionPointAfter(op);
998
999 // 1. Get the range of the loops that are represented by the operation.
1000 SmallVector<Range> iterationDomain = op.getIterationDomain(b&: rewriter);
1001
1002 // 2. Materialize the tile sizes and/or number of threads;
1003 SmallVector<OpFoldResult> tileSizes, numThreads;
1004 std::tie(args&: tileSizes, args&: numThreads) =
1005 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
1006
1007 // Check if it is safe to tile. This is hold over from previous iterations
1008 // of tile to for-all. Consider dropping it.
1009 if (failed(Result: checkTileSizes(op, loopType: options.loopType, reductionStrategy: options.reductionStrategy,
1010 tileSizes, numThreads))) {
1011 return failure();
1012 }
1013
1014 // Get the reduction dims
1015 SetVector<unsigned> reductionDims =
1016 getSanitizedReductionDims(tileSizes, options);
1017
1018 // 3. If there is an interchange specified, permute the iteration domain and
1019 // the tile sizes.
1020 SmallVector<int64_t> interchangeVector;
1021 if (!options.interchangeVector.empty()) {
1022 interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector,
1023 iterationDomainSize: iterationDomain.size());
1024 assert(isPermutationVector(interchangeVector) &&
1025 "expected interchange vector to be a permutation");
1026
1027 applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector);
1028 applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector);
1029 if (!numThreads.empty())
1030 applyPermutationToVector(inVec&: numThreads, permutation: interchangeVector);
1031 }
1032
1033 FailureOr<TilingResult> tilingResult;
1034 // 4. Define the lambda function used later to generate the body of the
1035 // innermost tiled loop.
1036 YieldTiledValuesFn innerYieldTiledValuesFn =
1037 [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
1038 ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
1039 SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
1040 SmallVector<SmallVector<OpFoldResult>> &resultSizes)
1041 -> LogicalResult {
1042 // 4a. Compute the `offsets` and `sizes` to use for tiling.
1043 SmallVector<OpFoldResult> offsets, sizes;
1044 std::tie(args&: offsets, args&: sizes) = getTileOffsetAndSizes(
1045 rewriter, loc, strategy: options.reductionStrategy, ivs, iterationDomain,
1046 tileSizes, numThreads, reductionDims);
1047
1048 // 4b. If interchange was provided, apply inverse of the interchange
1049 // to get back the offsets/sizes in the order to be specified.
1050 if (!interchangeVector.empty()) {
1051 auto inversePermutation = invertPermutationVector(permutation: interchangeVector);
1052 applyPermutationToVector(inVec&: offsets, permutation: inversePermutation);
1053 applyPermutationToVector(inVec&: sizes, permutation: inversePermutation);
1054 }
1055
1056 // 5. Generate the tiled implementation within the inner most loop.
1057
1058 // 5a. Clone the operation within the loop body.
1059 auto clonedOp = cast<TilingInterface>(
1060 Val: cloneOpAndUpdateDestinationArgs(rewriter, op, newDestArgs: regionIterArgs));
1061
1062 // 5b. Early return cloned op if tiling is not happening. We can not
1063 // return the original op because it could lead to `rewriter.replaceOp(op,
1064 // op->getResults())` and users would get crash.
1065 if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) {
1066 tiledResults.append(in_start: clonedOp->result_begin(), in_end: clonedOp->result_end());
1067 tilingResult =
1068 TilingResult{/*tiledOps=*/{clonedOp}, .tiledValues: clonedOp->getResults(),
1069 /*generatedSlices=*/{}};
1070 return success();
1071 }
1072
1073 // 5c. Tile the cloned operation.
1074 tilingResult = getTiledImplementation(
1075 rewriter, op: clonedOp, reductionStrategy: options.reductionStrategy, regionIterArg: regionIterArgs, offsets,
1076 sizes, ivs, numThreads, tileSizes, reductionDims);
1077 if (failed(Result: tilingResult)) {
1078 rewriter.eraseOp(op: clonedOp);
1079 return op.emitOpError(message: "faild to tile operation");
1080 }
1081
1082 // 5d. Delete the cloned operation.
1083 rewriter.eraseOp(op: clonedOp);
1084
1085 // 5e. Compute the offsets at which the result values are to be inserted
1086 // back into its destinations.
1087 for (auto [index, tiledValue] :
1088 llvm::enumerate(First&: tilingResult->tiledValues)) {
1089 tiledResults.push_back(Elt: tiledValue);
1090 SmallVector<OpFoldResult> resultOffset, resultSize;
1091 if (failed(Result: getResultTilePosition(
1092 rewriter, reductionStrategy: options.reductionStrategy, index, tiledResult: tiledValue, op,
1093 offsets, sizes, ivs, numThreads, tileSizes, reductionDims,
1094 resultOffset, resultSize))) {
1095 for (auto op : tilingResult->tiledOps) {
1096 rewriter.eraseOp(op);
1097 }
1098 return rewriter.notifyMatchFailure(
1099 arg&: op, msg: "failed to get slice of result produced");
1100 }
1101 resultOffsets.emplace_back(Args: std::move(resultOffset));
1102 resultSizes.emplace_back(Args: std::move(resultSize));
1103 }
1104
1105 return success();
1106 };
1107
1108 // 6. Find the destination tensors to use for the operation.
1109 FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
1110 rewriter, op, reductionStrategy: options.reductionStrategy, iterationDomain, numThreads,
1111 tileSizes, reductionDims);
1112 if (failed(Result: maybeInits)) {
1113 return rewriter.notifyMatchFailure(
1114 arg&: op, msg: "unable to create initial tensors for tiling");
1115 }
1116 SmallVector<Value> &initTensors = maybeInits.value();
1117
1118 // 7. Generate the tiled loops nest using the callback defined above.
1119 SmallVector<LoopLikeOpInterface> loops;
1120 if (failed(Result: generateLoopNest(rewriter, loc: op.getLoc(), loopType: options.loopType,
1121 loopRanges: iterationDomain, tileSizes, numThreads,
1122 destinationTensors: initTensors, mappingVector: options.mappingVector,
1123 tiledBodyFn: innerYieldTiledValuesFn, loops)))
1124 return op.emitOpError(message: "failed to generate tiling loops");
1125 assert(succeeded(tilingResult) &&
1126 "expected tiling result to be computed after loop generation");
1127
1128 if (loops.empty()) {
1129 // If loops are empty, the tiled op is used as the replacement for the
1130 // untiled op.
1131 return scf::SCFTilingResult{.tiledOps: tilingResult->tiledOps,
1132 .initialValues: initTensors,
1133 .loops: loops,
1134 .replacements: tilingResult->tiledValues,
1135 .generatedSlices: tilingResult->generatedSlices,
1136 .mergeOps: {}};
1137 }
1138
1139 auto loopResults = llvm::map_to_vector(C: loops.front()->getResults(),
1140 F: [](OpResult r) -> Value { return r; });
1141
1142 // For the full reduction case, there is nothing more to do.
1143 if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
1144 return scf::SCFTilingResult{
1145 .tiledOps: tilingResult->tiledOps, .initialValues: initTensors, .loops: loops, .replacements: loopResults,
1146 .generatedSlices: tilingResult->generatedSlices, .mergeOps: {}};
1147 }
1148
1149 // The results of the loop needs to be merged.
1150 FailureOr<MergeResult> mergeResult = mergeTilingResults(
1151 rewriter, op, reductionStrategy: options.reductionStrategy, reductionDims, partialResults: loopResults);
1152 if (failed(Result: mergeResult)) {
1153 return rewriter.notifyMatchFailure(
1154 arg&: op, msg: "Failed to merge partial results from tiling");
1155 }
1156 return scf::SCFTilingResult{.tiledOps: tilingResult->tiledOps,
1157 .initialValues: initTensors,
1158 .loops: loops,
1159 .replacements: mergeResult->replacements,
1160 .generatedSlices: tilingResult->generatedSlices,
1161 .mergeOps: mergeResult->mergeOps};
1162}
1163
1164FailureOr<scf::SCFTilingResult>
1165mlir::scf::tileReductionUsingScf(RewriterBase &b,
1166 PartialReductionOpInterface op,
1167 ArrayRef<OpFoldResult> tileSize) {
1168 scf::SCFTilingOptions options;
1169 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1170 options.setReductionTilingStrategy(
1171 ReductionTilingStrategy::PartialReductionOuterReduction);
1172 options.setTileSizes(tileSize);
1173 SmallVector<unsigned> reductionDims;
1174 for (auto [index, iteratorType] : llvm::enumerate(First: op.getLoopIteratorTypes()))
1175 if (iteratorType == utils::IteratorType::reduction)
1176 reductionDims.push_back(Elt: index);
1177 options.setReductionDims(reductionDims);
1178 return tileUsingSCF(rewriter&: b, op, options);
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// tileConsumerAndFuseProducersUsingSCF implementation.
1183//===----------------------------------------------------------------------===//
1184
/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outer most that is encountered. Traversing the
/// iter_args indicates that this is a destination operand of the consumer. If
/// there was no loop traversal needed, the second value of the returned tuple
/// is empty.
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
                                  ArrayRef<LoopLikeOpInterface> loops) {
  std::optional<OpOperand *> destinationIterArg;
  assert(!loops.empty() && "expected non empty loops container");
  // Walk from the innermost loop outwards: while the slice source is a block
  // argument of the current loop, replace it with the corresponding init
  // operand of that loop.
  auto loopIt = loops.rbegin();
  while (loopIt != loops.rend() && isa<BlockArgument>(Val: source->get())) {
    auto iterArg = cast<BlockArgument>(Val: source->get());
    auto loop = *loopIt;
    if (iterArg.getOwner()->getParentOp() != loop)
      break;
    source = loop.getTiedLoopInit(bbArg: iterArg);
    loopIt++;
  }
  // If the walk traversed all loops, the operand is an init of the outermost
  // loop, i.e. a destination operand of the untiled consumer.
  if (loopIt == loops.rend())
    destinationIterArg = source;
  return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg};
}
1209
/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
    MutableArrayRef<LoopLikeOpInterface> loops) {
  // 1. Get the producer of the source (potentially walking through
  // `iter_args` of nested `scf.for`)
  auto [fusableProducer, destinationInitArg] =
      getUntiledProducerFromSliceSource(source: &candidateSliceOp.getSourceMutable(),
                                        loops);
  if (!fusableProducer)
    return std::nullopt;
  unsigned resultNumber = fusableProducer.getResultNumber();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(candidateSliceOp);

  // 2. Clone the fused producer
  // 2a. Compute the destination operands to use for the cloned operation.
  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
  Operation *fusableProducerOp = fusableProducer.getOwner();
  if (isa<DestinationStyleOpInterface>(Val: fusableProducerOp) &&
      failed(Result: tensor::getOrCreateDestinations(
          b&: rewriter, loc: fusableProducerOp->getLoc(), op: fusableProducerOp,
          result&: origDestinationTensors)))
    return std::nullopt;

  clonedOpDestinationTensors = origDestinationTensors;
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(Val: fusableProducerOp)) {
    // 2b. If the producer is also destination style, then to maintain the
    // destination passing style, update the destination of the producer to be
    // the source of the slice.
    clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
  }
  // 2c. Clone the fused producer.
  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
      rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors);
  // 2d. Update the source of the candidateSlice to be the cloned producer.
  // Easier to just clone the slice with different source since
  // replacements and DCE of cloned ops becomes easier
  SmallVector<Value> candidateSliceOpOperands =
      llvm::to_vector(Range: candidateSliceOp->getOperands());
  candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber);
  tensor::ExtractSliceOp clonedCandidateSliceOp =
      mlir::clone(b&: rewriter, op: candidateSliceOp,
                  newResultTypes: candidateSliceOp->getResultTypes(), newOperands: candidateSliceOpOperands);

  // 3. Generate the tiled implementation of the producer of the source
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceExtractSliceWithTiledProducer(
          builder&: rewriter, sliceOp: clonedCandidateSliceOp,
          producerOp: clonedProducerOp->getResult(idx: resultNumber));
  if (failed(Result: tileAndFuseResult))
    return std::nullopt;
  // Note: Do not delete the candidateSliceOp, since its passed in from the
  // caller.
  rewriter.replaceAllUsesWith(from: candidateSliceOp,
                              to: tileAndFuseResult->tiledValues[0]);
  rewriter.eraseOp(op: clonedCandidateSliceOp);
  rewriter.eraseOp(op: clonedProducerOp);

  // 4. If the slice is for a destination operand, for example,
  //
  // ```mlir
  // %0 = linalg.init
  // %1 = linalg.fill .. outs(%0 : )
  // %2 = scf.for .. iter_args(%arg0 = %1) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1 [..]
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  //
  // the IR is currently
  //
  // ```
  // %0 = linalg.init
  // %1 = linalg.fill
  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1[..]
  //     %5 = linalg.fill .. outs(%4 : )
  //     .. = linalg.matmul .. outs(%5 : )
  //   }
  // }
  // ```
  //
  // The untiled `linalg.fill` is still used as the `init_value` since it
  // was originally a destination operand of the untiled `linalg.matmul`.
  // When fusing an operand that is a destination operand, the iter_arg of
  // the outer most loop should be changed to use the destination of the
  // fused operation. With this the IR will be.
  //
  // ```
  // %0 = linalg.init
  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %3 = tensor.extract_slice %arg1[..]
  //     %4 = linalg.fill .. outs(%3 : )
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(Val: fusableProducerOp) && !loops.empty()) {
    loops.front()
        ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
        .set(origDestinationTensors[resultNumber]);
  }
  return scf::SCFFuseProducerOfSliceResult{
      .origProducer: fusableProducer, .tiledAndFusedProducer: tileAndFuseResult->tiledValues[0],
      .tiledOps: tileAndFuseResult->tiledOps, .generatedSlices: tileAndFuseResult->generatedSlices};
}
1326
/// Reconstruct the fused producer from within the tiled-and-fused code.
/// Appends new init operands (the producer's destinations) to the loop nest
/// and yields the tiled producer results through them, returning the
/// `tensor.extract_slice` ops generated for the new destinations.
FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<LoopLikeOpInterface> loops,
    ArrayRef<unsigned> yieldResultNumber) {
  if (loops.empty())
    // NOTE(review): this converts a `success()` LogicalResult into a
    // FailureOr<SmallVector<Operation *>> — presumably callers do not read
    // the value in this case; confirm this conversion is intended.
    return success();

  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
            *tiledOwner = fusedProducerInfo.tiledOps[0];

  Location loc = originalOwner->getLoc();
  // a. collect all init Value to be appended. An empty `yieldResultNumber`
  // means all results of the producer are yielded.
  SmallVector<unsigned> initNumberList =
      yieldResultNumber.empty() ? llvm::to_vector(Range: llvm::seq<unsigned>(
                                      Begin: 0, End: originalOwner->getNumResults()))
                                : llvm::to_vector(Range&: yieldResultNumber);
  SmallVector<Value> initValueList;
  for (const auto &resultNumber : initNumberList) {
    FailureOr<Value> initValue = tensor::getOrCreateDestination(
        b&: rewriter, loc, opResult: originalOwner->getResult(idx: resultNumber));
    if (succeeded(Result: initValue)) {
      initValueList.push_back(Elt: initValue.value());
    } else {
      return failure();
    }
  }

  SmallVector<Operation *> generatedSlices;
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);

    // get sliceOp tile information
    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
                              sliceSizes = sliceOp.getMixedSizes();

    // expect all strides of sliceOp being 1
    if (!llvm::all_of(Range: sliceOp.getMixedStrides(), P: isOneInteger))
      return failure();

    unsigned sliceResultNumber =
        fusedProducerInfo.origProducer.getResultNumber();

    auto tilableOp = cast<TilingInterface>(Val: originalOwner);
    // b. get iterDomain Offset and Sizes based on sliceOp tile
    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
    // skip tensor.pack/unpack/pad, which expects single opResult
    if (tilableOp->getNumResults() > 1 &&
        failed(Result: tilableOp.getIterationDomainTileFromResultTile(
            b&: rewriter, resultNumber: sliceResultNumber, offsets: sliceOffset, sizes: sliceSizes,
            iterDomainOffsets&: iterDomainOffset, iterDomainSizes))) {
      // In theory, it is unnecessary to raise an error here. Actually
      // although it fails to reconstruct the result tensor, it should not
      // break current fusion anyway. The reason why we must return failure
      // currently is that the callback function `newYieldValuesFn` will be
      // called after new init operand(s) has already been appended. It will
      // take more refactoring to make sure the init operands are added
      // consistently in the future. For more details, please refer to:
      // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
      return failure();
    }

    // c. calculate offsets and sizes info of all OpResults respectively based
    // on iteration Domain Tile
    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
    for (const auto &resultNumber : initNumberList) {
      if (resultNumber == sliceResultNumber) {
        // The sliced result reuses the slice's own offsets/sizes directly.
        offsetList.push_back(Elt: sliceOffset);
        sizesList.push_back(Elt: sliceSizes);
      } else {
        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
        // infer result tile according to the iteration domain tile
        SmallVector<OpFoldResult> offset, sizes;
        if (failed(Result: tilableOp.getResultTilePosition(
                b&: rewriter, resultNumber, offsets: iterDomainOffset, sizes: iterDomainSizes,
                resultOffsets&: offset, resultSizes&: sizes))) {
          return failure();
        }
        offsetList.push_back(Elt: offset);
        sizesList.push_back(Elt: sizes);
      }
    }

    // d. create `extract_slice` for `iter_args` for DPS operation if
    // necessary, and point the tiled op's dps inits at the new slices.
    if (auto tiledDestStyleOp =
            dyn_cast<DestinationStyleOpInterface>(Val: tiledOwner)) {
      rewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(First&: newRegionIterArgs)) {
        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
            location: loc, args&: newRegionArg, args&: offsetList[index], args&: sizesList[index],
            args: SmallVector<OpFoldResult>(offsetList[index].size(),
                                      rewriter.getIndexAttr(value: 1)));
        generatedSlices.push_back(Elt: destSlice);
        unsigned resultNumber = initNumberList[index];
        rewriter.modifyOpInPlace(root: tiledDestStyleOp, callable: [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
        });
      }
    }

    // e. prepare tiled offset and sizes for later `insert_slice` creation by
    // caller
    Block *block = rewriter.getInsertionPoint()->getBlock();
    rewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, resultNumber] : llvm::enumerate(First&: initNumberList)) {
      tiledResult.push_back(Elt: tiledOwner->getResult(idx: resultNumber));
      tiledOffset.emplace_back(Args&: offsetList[index]);
      tiledSizes.emplace_back(Args&: sizesList[index]);
    }
    return success();
  };

  if (failed(Result: addInitOperandsToLoopNest(rewriter, loops, newInitValues: initValueList,
                                        getNewTiledYieldsFn: newYieldValuesFn))) {
    return failure();
  }
  return generatedSlices;
}
1452
1453namespace {
1454
1455//===----------------------------------------------------------------------===//
1456// SliceTrackingListener
1457//===----------------------------------------------------------------------===//
1458
1459/// This class is a listener for tracking the insertion and removal of
1460/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1461/// fusion algorithm to apply cleanup patterns in between fusion steps.
class SliceTrackingListener : public RewriterBase::Listener {
public:
  /// Construct with an optional set of cleanup patterns that are applied to
  /// newly inserted operations via `insertAndApplyPatterns`.
  explicit SliceTrackingListener(
      std::optional<FrozenRewritePatternSet> patterns);
  SliceTrackingListener() = default;

  /// Adds the given list of operations to the worklist, and if present,
  /// applies the list of `patterns` to the newly added operations. This only
  /// processes the given operations and any newly inserted ones by the
  /// pattern set.
  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);

  /// Add to the new operation worklist if it is an extract_slice.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Shared helper for operation removal from the worklist.
  void removeOp(Operation *op);

  /// Remove the operation from the worklist.
  void notifyOperationErased(Operation *op) override;

  /// Remove the operation from the worklist.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// The worklist for this transformation keeps track of the slices to visit
  /// next for fusion. A deque is used so candidates are visited in FIFO
  /// (insertion) order.
  std::deque<tensor::ExtractSliceOp> worklist;

private:
  /// Optional pattern set to apply when adding new operations to the
  /// worklist.
  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
1496
1497SliceTrackingListener::SliceTrackingListener(
1498 std::optional<FrozenRewritePatternSet> p) {
1499 patterns = std::move(p);
1500}
1501
1502LogicalResult
1503SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1504 for (Operation *op : ops) {
1505 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(Val: op))
1506 worklist.push_back(x: slice);
1507 }
1508
1509 if (!patterns)
1510 return success();
1511
1512 return applyOpPatternsGreedily(
1513 ops, patterns: patterns.value(),
1514 config: GreedyRewriteConfig().setListener(this).setStrictness(
1515 GreedyRewriteStrictness::ExistingAndNewOps));
1516}
1517
1518void SliceTrackingListener::notifyOperationInserted(
1519 Operation *op, OpBuilder::InsertPoint previous) {
1520 auto slice = dyn_cast<tensor::ExtractSliceOp>(Val: op);
1521 if (!slice)
1522 return;
1523 worklist.push_back(x: slice);
1524}
1525
1526// Scan the worklist for the given op and remove it if present. The
1527// expectation is for the worklist to be small and for removal to be
1528// relatively rare.
1529void SliceTrackingListener::removeOp(Operation *op) {
1530 if (!isa<tensor::ExtractSliceOp>(Val: op))
1531 return;
1532 auto iter = worklist.begin();
1533 while (iter != worklist.end()) {
1534 if (*iter == op)
1535 break;
1536 iter++;
1537 }
1538 if (iter == worklist.end())
1539 return;
1540
1541 worklist.erase(position: iter);
1542}
1543
// Forward erasure notifications to the shared removal helper.
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  removeOp(op);
}
1547
// A replaced op is no longer a fusion candidate; drop it from the worklist.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange replacement) {
  removeOp(op);
}
1552
1553//===----------------------------------------------------------------------===//
1554// ReplacementListener
1555//===----------------------------------------------------------------------===//
1556
1557/// Listener that tracks updates replacements for values which can be mutated.
1558/// This listener runs on top of the existing listener for the rewriter,
1559/// to make sure external users can still run listeners.
1560class ReplacementListener : public RewriterBase::ForwardingListener {
1561public:
1562 ReplacementListener(DenseMap<Value, Value> &replacements,
1563 OpBuilder::Listener *listener)
1564 : ForwardingListener(listener), replacements(replacements) {}
1565
1566 void updateReplacementValues(ValueRange origValues,
1567 ValueRange replaceValues) {
1568 // This can probably be written better, but just iterates over the map
1569 // and the new replacements for now.
1570 for (auto &[key, val] : replacements) {
1571 for (auto [orig, replace] : llvm::zip_equal(t&: origValues, u&: replaceValues)) {
1572 if (val == orig) {
1573 val = replace;
1574 }
1575 }
1576 }
1577 }
1578
1579 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1580 ForwardingListener::notifyOperationReplaced(op, newOp);
1581 updateReplacementValues(origValues: op->getResults(), replaceValues: newOp->getResults());
1582 }
1583
1584 void notifyOperationReplaced(Operation *op, ValueRange values) override {
1585 ForwardingListener::notifyOperationReplaced(op, replacement: values);
1586 updateReplacementValues(origValues: op->getResults(), replaceValues: values);
1587 }
1588
1589private:
1590 DenseMap<Value, Value> &replacements;
1591};
1592
1593} // namespace
1594
1595/// Implementation of tile consumer and fuse producer greedily.
1596FailureOr<scf::SCFTileAndFuseResult>
1597mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1598 RewriterBase &rewriter, TilingInterface consumer,
1599 const scf::SCFTileAndFuseOptions &options) {
1600 // This transformation is only valid for ops that return values (i.e. not
1601 // valid to use with operations that have memref operands).
1602 if (!consumer->getNumResults()) {
1603 return rewriter.notifyMatchFailure(
1604 arg&: consumer, msg: "invalid pattern for op with no results");
1605 }
1606
1607 // 1. First tile the consumer.
1608 SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1609
1610 FailureOr<scf::SCFTilingResult> tilingResult =
1611 tileUsingSCF(rewriter, op: consumer, options: options.tilingOptions);
1612
1613 if (failed(Result: tilingResult))
1614 return rewriter.notifyMatchFailure(arg&: consumer, msg: "failed to tile consumer");
1615 tiledAndFusedOps.insert_range(R&: tilingResult->tiledOps);
1616
1617 DenseMap<Value, Value> replacements;
1618 for (auto [origVal, replacement] :
1619 llvm::zip_equal(t: consumer->getResults(), u&: tilingResult->replacements)) {
1620 replacements[origVal] = replacement;
1621 }
1622
1623 // If there are no loops generated, fusion is immaterial.
1624 auto &loops = tilingResult->loops;
1625 if (loops.empty()) {
1626 return scf::SCFTileAndFuseResult{.fusedProducers: fusedProducers, .tiledAndFusedOps: tiledAndFusedOps, .loops: loops,
1627 .replacements: replacements};
1628 }
1629
1630 // Since the loop gets potentially replaced during fusion, we need to track
1631 // the mutation of replacement values. To do this, we attach a listener to
1632 // update the replacements as they happen.
1633 OpBuilder::Listener *previousListener = rewriter.getListener();
1634 auto resetListener =
1635 llvm::make_scope_exit(F: [&]() { rewriter.setListener(previousListener); });
1636 ReplacementListener replaceListener(replacements, previousListener);
1637 rewriter.setListener(&replaceListener);
1638
1639 // 2. Typically, the operands of the tiled operation are slices of the
1640 // operands of the untiled operation. These are expressed in IR using
1641 // `tensor.extract_slice` operations with source being the operands of
1642 // the untiled operation. Create a worklist of these
1643 // `tensor.extract_slice` operations. If the producers of the source of
1644 // the `tensor.extract_slice` can be tiled such that the tiled value is
1645 // generated in-place, that effectively tiles + fuses the operations.
1646 struct WorklistItem {
1647 tensor::ExtractSliceOp candidateSlice;
1648 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1649 };
1650
1651 SliceTrackingListener sliceTracker =
1652 SliceTrackingListener(options.cleanupPatterns);
1653
1654 if (failed(
1655 Result: sliceTracker.insertAndApplyPatterns(ops: tilingResult->generatedSlices))) {
1656 return rewriter.notifyMatchFailure(arg&: consumer, msg: "cleanup patterns failed");
1657 }
1658 OpBuilder::InsertionGuard g(rewriter);
1659 while (!sliceTracker.worklist.empty()) {
1660 auto candidateSlice = sliceTracker.worklist.front();
1661 sliceTracker.worklist.pop_front();
1662
1663 auto [fusableProducer, destinationInitArg] =
1664 getUntiledProducerFromSliceSource(source: &candidateSlice.getSourceMutable(),
1665 loops);
1666 if (!fusableProducer)
1667 continue;
1668
1669 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1670 options.fusionControlFn(candidateSlice, fusableProducer,
1671 destinationInitArg.has_value());
1672 if (!controlFnResult)
1673 continue;
1674
1675 WorklistItem worklistItem = {.candidateSlice: candidateSlice, .controlFnResult: controlFnResult.value()};
1676
1677 // The operands of the fused producer might themselved be slices of
1678 // values produced by operations that implement the `TilingInterface`.
1679 // Add these operations to the worklist.
1680 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1681 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp: worklistItem.candidateSlice,
1682 loops);
1683 if (!fusedResult)
1684 continue;
1685
1686 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1687
1688 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1689 // Reconstruct and yield all opResult of fusableProducerOp by default.
1690 // The caller can specific which one to yield by designating optional
1691 // argument named `yieldResultNumber` of
1692 // `yieldReplacementForFusedProducer`.
1693 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1694 FailureOr<SmallVector<Operation *>> newSlices =
1695 yieldReplacementForFusedProducer(rewriter,
1696 sliceOp: worklistItem.candidateSlice,
1697 fusedProducerInfo: fusedResult.value(), loops);
1698 if (failed(Result: newSlices)) {
1699 return rewriter.notifyMatchFailure(
1700 arg&: fusableProducerOp, msg: "failed to replacement value for this "
1701 "operation from within the tiled loop");
1702 }
1703 worklistCandidates.append(RHS: newSlices.value());
1704 for (auto [index, result] :
1705 llvm::enumerate(First: fusableProducerOp->getResults())) {
1706 replacements[result] = loops.front()->getResult(
1707 idx: loops.front()->getNumResults() -
1708 fusableProducerOp->getNumResults() + index);
1709 }
1710 }
1711 if (Operation *tiledAndFusedOp =
1712 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1713 fusedProducers.insert(X: fusedResult->origProducer.getDefiningOp());
1714 tiledAndFusedOps.insert(X: tiledAndFusedOp);
1715 }
1716
1717 if (failed(Result: sliceTracker.insertAndApplyPatterns(ops: worklistCandidates))) {
1718 return rewriter.notifyMatchFailure(arg&: consumer, msg: "cleanup patterns failed");
1719 }
1720 }
1721
1722 return scf::SCFTileAndFuseResult{.fusedProducers: fusedProducers, .tiledAndFusedOps: tiledAndFusedOps, .loops: loops,
1723 .replacements: replacements};
1724}
1725
1726//===----------------------------------------------------------------------===//
1727// tileAndFuseConsumerUsingSCF implementation.
1728//===----------------------------------------------------------------------===//
1729
1730/// A utility function that checks whether the only use of the result of a
1731/// tensor.insert_slice op is in a scf.yield op.
1732static LogicalResult
1733checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1734 Value result = candidateSliceOp.getResult();
1735 Value::use_range uses = result.getUses();
1736 if (!llvm::hasSingleElement(C&: uses)) {
1737 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1738 return failure();
1739 }
1740 OpOperand &operandUse = (*uses.begin());
1741 Operation *userOp = operandUse.getOwner();
1742 if (!isa<scf::YieldOp>(Val: userOp)) {
1743 LLVM_DEBUG(llvm::dbgs()
1744 << "Expected scf.yield to be the only user, but got -> "
1745 << (*userOp));
1746 return failure();
1747 }
1748 if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1749 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1750 "be in the same block\n");
1751 return failure();
1752 }
1753 return success();
1754}
1755
1756/// An utility to get the first user of the given loopOp. If any of user stay
1757/// in different block of loopOp, return failure.
1758static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1759 if (!isa<LoopLikeOpInterface>(Val: loopOp))
1760 return failure();
1761 Operation *firstUserOfLoop = nullptr;
1762 for (Operation *userOp : loopOp->getUsers()) {
1763 // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1764 // block with any other types of operation. Thus, just redirecting to its
1765 // parent `InParallelOp`. E.g.
1766 //
1767 // ```
1768 // %1 = scf.for {
1769 // ...
1770 // }
1771 // %2 = consumerOp ins(%1, ...)
1772 // scf.forall.in_parallel {
1773 // tensor.parallel_insert_slice %1
1774 // }
1775 // ```
1776 // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1777 // same block with `consumerOp`.
1778 if (isa<tensor::ParallelInsertSliceOp>(Val: userOp))
1779 userOp = userOp->getParentOfType<scf::InParallelOp>();
1780
1781 if (loopOp->getBlock() != userOp->getBlock())
1782 return failure();
1783
1784 if (!firstUserOfLoop || userOp->isBeforeInBlock(other: firstUserOfLoop))
1785 firstUserOfLoop = userOp;
1786 }
1787 return firstUserOfLoop;
1788}
1789
1790/// This utility currently checks whether the first userOp of loop is NOT
1791/// before the last defineOp of consumer operand. Because that we need to move
1792/// the whole loop structure right before the `firstUserOfLoop`. This utility
1793/// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1794/// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1795///
1796/// ```
1797/// %0 = scf.for() {
1798/// ...
1799/// }
1800/// ...
1801/// %1 = firstUserOfLoop(%0)
1802/// ...
1803/// %2 = lastDefOfConsumerOperand
1804/// ...
1805/// %3 = consumerOp(%2)
1806/// ```
1807///
1808/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1809/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1810/// a.k.a. use-def chain violation:
1811///
1812/// ```
1813/// %0:2 = scf.for() {
1814/// // use before define error
1815/// %3 = tiledConsumerOp(%2)
1816/// }
1817/// %1 = firstUserOfLoop(%0)
1818/// ...
1819/// %2 = lastDefOfConsumerOperand
1820/// ```
1821///
1822/// @param loopOp: loop operation
1823/// @param consumerOp: consumer operation
1824/// @param reorderOperations: the flag controls whether to reorder the
1825/// backward slice w.r.t. the defineOp of `consumerOp` operands.
1826/// @return: computed backward slice of consumerOp, but excluding those
1827/// already dominates `firstUserOfLoop`.
1828static FailureOr<llvm::SetVector<Operation *>>
1829checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
1830 bool reorderOperations) {
1831 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1832 if (failed(Result: firstUserOfLoop))
1833 return failure();
1834
1835 BackwardSliceOptions options;
1836 DominanceInfo dominanceInfo;
1837 options.inclusive = true;
1838 options.omitBlockArguments = true;
1839 bool includeLoopOp = false;
1840 options.filter = [&](Operation *op) {
1841 if (op == loopOp) {
1842 includeLoopOp = true;
1843 return false;
1844 }
1845 // Cut off the slice to not include any operation that already dominates
1846 // firstUserOfLoop.
1847 return !dominanceInfo.properlyDominates(a: op, b: *firstUserOfLoop);
1848 };
1849 llvm::SetVector<Operation *> slice;
1850 for (auto operand : consumerOp->getOperands()) {
1851 LogicalResult result = getBackwardSlice(root: operand, backwardSlice: &slice, options);
1852 assert(result.succeeded() && "expected a backward slice");
1853 (void)result;
1854 }
1855
1856 if (!slice.empty()) {
1857 // If consumerOp has one producer, which is also the user of loopOp.
1858 // E.g.
1859 // ```
1860 // %0 = %loopOp
1861 // %1 = consumerOp1 ins(%0)
1862 // %2 = consumerOp2 ins(%0, %1)
1863 // ```
1864 // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1865 // consumerOp1 has already been fused into loopOp before.
1866 if (includeLoopOp || !reorderOperations)
1867 return failure();
1868 }
1869
1870 return slice;
1871}
1872
1873/// Fetches the OpOperand of the first valid user (and use) of the value `val`
1874/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1875/// Returns failure otherwise.
1876static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1877 Operation *loopOp,
1878 unsigned resultNumber) {
1879 if (!isa<LoopLikeOpInterface>(Val: loopOp))
1880 return failure();
1881 Value val = loopOp->getResult(idx: resultNumber);
1882 Block *loopBlock = loopOp->getBlock();
1883 for (OpOperand &opOperand : val.getUses()) {
1884 Operation *consumerOp = opOperand.getOwner();
1885 // Step 1. Check if the user is tilable.
1886 if (!isa<TilingInterface>(Val: consumerOp) ||
1887 !isa<DestinationStyleOpInterface>(Val: consumerOp)) {
1888 // TODO: We have to init result of consumer before scf.for, use
1889 // DestinationStyleOpInterface to get result shape from init for now.
1890 // Add support for other op such as op has InferTypeOpInterface.
1891 continue;
1892 }
1893 // Step 2. Check if user stay in the same block.
1894 if (loopBlock != consumerOp->getBlock())
1895 continue;
1896 // Step 3. Check if user has succeeding user. Otherwise, it usually
1897 // represents already tiled.
1898 if (consumerOp->use_empty())
1899 continue;
1900 // Step 4. Check assumption for loop with `reorderOperations` enabled.
1901 FailureOr<llvm::SetVector<Operation *>> slice =
1902 checkAssumptionForLoop(loopOp, consumerOp, reorderOperations: true);
1903 if (failed(Result: slice))
1904 continue;
1905 // Step 5. If backward sice is not empty, move them before
1906 // firstUserOfLoop.
1907 if (!slice->empty()) {
1908 mlir::topologicalSort(toSort: *slice);
1909 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1910 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1911 for (auto op : *slice) {
1912 rewriter.moveOpBefore(op, existingOp: *firstUserOfLoop);
1913 }
1914 }
1915 return &opOperand;
1916 }
1917 return failure();
1918}
1919
1920/// Check that the loop is perfectly nested.
1921/// The loops are expected to be ordered from outer most to inner most.
1922/// For example:
1923/// ```
1924/// %0 = scf.for()
1925/// %1 = scf.for()
1926/// %2 = scf.for()
1927/// %3 = ...
1928/// yield %3
1929/// yield %2
1930/// yield %1
1931/// ```
1932/// Here loops should be [%0, %1].
1933static bool
1934isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
1935 assert(!loops.empty() && "unexpected empty loop nest");
1936 if (loops.size() == 1) {
1937 return isa_and_nonnull<scf::ForOp>(Val: loops.front().getOperation());
1938 }
1939 for (auto [outerLoop, innerLoop] :
1940 llvm::zip_equal(t: loops.drop_back(), u: loops.drop_front())) {
1941 auto outerFor = dyn_cast_or_null<scf::ForOp>(Val: outerLoop.getOperation());
1942 auto innerFor = dyn_cast_or_null<scf::ForOp>(Val: innerLoop.getOperation());
1943 if (!outerFor || !innerFor) {
1944 return false;
1945 }
1946 auto outerBBArgs = outerFor.getRegionIterArgs();
1947 auto innerIterArgs = innerFor.getInitArgs();
1948 if (outerBBArgs.size() != innerIterArgs.size()) {
1949 return false;
1950 }
1951
1952 for (auto [outerBBArg, innerIterArg] :
1953 llvm::zip_equal(t&: outerBBArgs, u&: innerIterArgs)) {
1954 if (!llvm::hasSingleElement(C: outerBBArg.getUses()) ||
1955 innerIterArg != outerBBArg) {
1956 return false;
1957 }
1958 }
1959
1960 ValueRange outerYields =
1961 cast<scf::YieldOp>(Val: outerFor.getBody()->getTerminator())->getOperands();
1962 ValueRange innerResults = innerFor.getResults();
1963 if (outerYields.size() != innerResults.size()) {
1964 return false;
1965 }
1966 for (auto [outerYield, innerResult] :
1967 llvm::zip_equal(t&: outerYields, u&: innerResults)) {
1968 if (!llvm::hasSingleElement(C: innerResult.getUses()) ||
1969 outerYield != innerResult) {
1970 return false;
1971 }
1972 }
1973 }
1974 return true;
1975}
1976
1977/// Fetch the untiled consumer of the outermost scf.for's result which is
1978/// yielded by a tensor.insert_slice from the innermost scf.for. This function
1979/// makes the following assumptions :
1980/// 1. tensor.insert_slice has scf.yield as its only user.
1981/// 2. scf.for's corresponding result has only one use.
1982/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
1983static FailureOr<OpOperand *>
1984getUntiledConsumerFromSlice(RewriterBase &rewriter,
1985 tensor::InsertSliceOp candidateSliceOp,
1986 MutableArrayRef<LoopLikeOpInterface> loops) {
1987 assert(!loops.empty() && "unexpected loops to be empty");
1988 // 1. Expect slice to be part of the body of the inner most loop.
1989 Operation *containingOp = candidateSliceOp->getParentOp();
1990 if (containingOp != loops.back()) {
1991 return rewriter.notifyMatchFailure(
1992 arg&: candidateSliceOp,
1993 msg: "expected slice to be within body of inner-most loop");
1994 }
1995
1996 // 2. Check that the loop is perfectly nested.
1997 if (!isPerfectlyNestedForLoops(loops)) {
1998 return rewriter.notifyMatchFailure(
1999 arg&: candidateSliceOp, msg: "expected passed loops to be perfectly nested.");
2000 }
2001
2002 if (failed(Result: checkAssumptionForFusingConsumer(candidateSliceOp)))
2003 return failure();
2004 Value sliceResult = candidateSliceOp.getResult();
2005
2006 // 3. Fetch the corresponding output.
2007 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
2008 unsigned resultNumber = yieldOpOperand.getOperandNumber();
2009
2010 scf::ForOp topLevelForOp = cast<scf::ForOp>(Val: loops.front().getOperation());
2011
2012 return getConsumerFromLoopUses(rewriter, loopOp: topLevelForOp, resultNumber);
2013}
2014
2015/// Fetch the first untiled consumer of a scf.forall's result which is yielded
2016/// by a tensor.parallel_insert_slice.
2017static FailureOr<OpOperand *>
2018getUntiledConsumerFromSlice(RewriterBase &rewriter,
2019 tensor::ParallelInsertSliceOp candidateSliceOp,
2020 MutableArrayRef<LoopLikeOpInterface> loops) {
2021 assert(!loops.empty() && "unexpected loops to be empty");
2022 // 1. Check that the surrounding loop is a single scf.forall loop.
2023 if (loops.size() != 1) {
2024 return rewriter.notifyMatchFailure(
2025 arg&: candidateSliceOp, msg: "expected single surrounding scf.forall");
2026 }
2027 auto forallOp = dyn_cast<scf::ForallOp>(Val: loops.front().getOperation());
2028 if (!forallOp) {
2029 return rewriter.notifyMatchFailure(
2030 arg&: candidateSliceOp, msg: "expected single surrounding scf.forall");
2031 }
2032
2033 // 2. Fetch the corresponding output
2034 Value sliceDest = candidateSliceOp.getDest();
2035 auto iterArg = dyn_cast<BlockArgument>(Val&: sliceDest);
2036 if (!iterArg)
2037 return failure();
2038 if (iterArg.getOwner()->getParentOp() != forallOp)
2039 return failure();
2040
2041 unsigned resultNumber =
2042 forallOp.getTiedOpResult(opOperand: forallOp.getTiedOpOperand(bbArg: iterArg))
2043 .getResultNumber();
2044
2045 return getConsumerFromLoopUses(rewriter, loopOp: forallOp, resultNumber);
2046}
2047
2048/// A utility to fetch an untiled consumer of
2049/// tensor.insert_slice/tensor.parallel_insert_slice.
2050static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
2051 RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
2052 MutableArrayRef<LoopLikeOpInterface> loops) {
2053 assert(!loops.empty() && "unexpected empty loops");
2054 assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
2055 SmallVector<OpOperand *> fusedOperands;
2056 for (auto sliceOp : sliceOps) {
2057 FailureOr<OpOperand *> fusedOperand =
2058 TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp)
2059 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2060 caseFn: [&](auto op) {
2061 return getUntiledConsumerFromSlice(rewriter, op, loops);
2062 })
2063 .Default(defaultFn: [&](Operation *op) {
2064 return rewriter.notifyMatchFailure(arg&: op, msg: "unhandled slice type");
2065 });
2066 if (failed(Result: fusedOperand)) {
2067 return failure();
2068 }
2069 if (!fusedOperands.empty() &&
2070 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2071 return rewriter.notifyMatchFailure(
2072 arg: fusedOperand.value()->getOwner(),
2073 msg: "all candidate slices must be to the same consumer");
2074 }
2075 fusedOperands.push_back(Elt: fusedOperand.value());
2076 }
2077 return fusedOperands;
2078}
2079
/// Clone the given insert-slice-like op as an equivalent
/// `tensor.insert_slice` op. Specialized below for `tensor.insert_slice`
/// and `tensor.parallel_insert_slice`.
template <typename InsertSliceOpTy>
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
                                                InsertSliceOpTy sliceOp);
2083
2084template <>
2085tensor::InsertSliceOp
2086cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
2087 tensor::InsertSliceOp insertSliceOp) {
2088 return cast<tensor::InsertSliceOp>(
2089 Val: rewriter.clone(op&: *insertSliceOp.getOperation()));
2090}
2091
2092template <>
2093tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
2094 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2095 return rewriter.create<tensor::InsertSliceOp>(
2096 location: insertSliceOp->getLoc(), args: insertSliceOp.getSource(),
2097 args: insertSliceOp.getDest(), args: insertSliceOp.getMixedOffsets(),
2098 args: insertSliceOp.getMixedSizes(), args: insertSliceOp.getMixedStrides());
2099}
2100
2101static SmallVector<tensor::InsertSliceOp>
2102cloneAsInsertSlices(RewriterBase &rewriter,
2103 ArrayRef<Operation *> candidateSlices) {
2104 assert(!candidateSlices.empty() &&
2105 "unexpected empty list of slices to clone");
2106 SmallVector<tensor::InsertSliceOp> clonedSlices;
2107 for (auto sliceOp : candidateSlices) {
2108 TypeSwitch<Operation *>(sliceOp)
2109 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2110 caseFn: [&](auto op) {
2111 auto clonedOp = cloneAsInsertSlice(rewriter, op);
2112 clonedSlices.push_back(Elt: clonedOp);
2113 })
2114 .Default(defaultFn: [&](Operation *op) {
2115 // Assert here assuming this has already been checked.
2116 assert(0 && "unexpected slice type while cloning as insert slice");
2117 });
2118 }
2119 return clonedSlices;
2120}
2121
2122/// Implementation of fusing consumer of a single slice by computing the
2123/// slice of the consumer in-place for scf loop.
2124FailureOr<scf::SCFFuseConsumerOfSliceResult>
2125mlir::scf::tileAndFuseConsumerOfSlices(
2126 RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2127 MutableArrayRef<LoopLikeOpInterface> loops) {
2128 if (candidateSlices.empty()) {
2129 return rewriter.notifyMatchFailure(
2130 arg: rewriter.getUnknownLoc(),
2131 msg: "no candidate slices provided for consumer fusion");
2132 }
2133 // Return if `loops` is empty, return an error for now. Caller is expected
2134 // to handle this case.
2135 if (loops.empty()) {
2136 return rewriter.notifyMatchFailure(
2137 arg: candidateSlices.front(),
2138 msg: "cannot call tile and fuse consumer with an empty loop nest");
2139 }
2140
2141 if (!(llvm::all_of(Range&: candidateSlices, P: llvm::IsaPred<tensor::InsertSliceOp>) ||
2142 llvm::all_of(Range&: candidateSlices,
2143 P: llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2144 return rewriter.notifyMatchFailure(
2145 arg: candidateSlices.front(),
2146 msg: "candidates slices need to be all `tensor.extract_slice`s or "
2147 "`tensor.parallel_insert_slice`s");
2148 }
2149
2150 // 1. Get the consumer of scf.for for the result yielded by
2151 // tensor.insert_slice/parallel_insert_slice.
2152 SmallVector<OpOperand *> consumerOpOperands;
2153 Operation *consumerOp;
2154 {
2155 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2156 getUntiledConsumerOperandsFromSlices(rewriter, sliceOps: candidateSlices, loops);
2157 if (failed(Result: maybeConsumerOpOperand)) {
2158 return rewriter.notifyMatchFailure(arg: candidateSlices.front(),
2159 msg: "could not fetch consumer to fuse");
2160 }
2161 std::swap(LHS&: consumerOpOperands, RHS&: maybeConsumerOpOperand.value());
2162 consumerOp = consumerOpOperands.front()->getOwner();
2163 }
2164
2165 LoopLikeOpInterface outerMostLoop = loops.front();
2166 LoopLikeOpInterface innerMostLoop = loops.back();
2167
2168 // Check assumption for loop with `reorderOperations` disabled.
2169 if (failed(Result: checkAssumptionForLoop(loopOp: outerMostLoop, consumerOp, reorderOperations: false))) {
2170 return rewriter.notifyMatchFailure(
2171 arg&: outerMostLoop, msg: "the first user of loop should not dominate any define "
2172 "of consumer operand(s)");
2173 }
2174
2175 OpBuilder::InsertionGuard g(rewriter);
2176
2177 // 2. Check consumer is not using scf loop's output as init.
2178 auto dstOp = dyn_cast<DestinationStyleOpInterface>(Val: consumerOp);
2179 if (!dstOp)
2180 return rewriter.notifyMatchFailure(arg&: consumerOp,
2181 msg: "consumer op is not DPS operation");
2182 if (llvm::any_of(Range&: consumerOpOperands, P: [&](OpOperand *opOperand) {
2183 return dstOp.isDpsInit(opOperand);
2184 })) {
2185 return rewriter.notifyMatchFailure(
2186 arg&: consumerOp,
2187 msg: "consumer op taking the result of scf.for as init is not supported");
2188 }
2189 SmallVector<Value> newInits = llvm::to_vector(Range: dstOp.getDpsInits());
2190
2191 // 3. Move the whole loop structure right before firstUserOfLoop, the
2192 // dominance should be already ensured by `checkAssumptionForLoop`.
2193 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp: outerMostLoop);
2194 if (failed(Result: firstUserOfLoop)) {
2195 return rewriter.notifyMatchFailure(
2196 arg&: outerMostLoop, msg: "could not find the first user of outer most loop");
2197 }
2198 rewriter.moveOpBefore(op: outerMostLoop, existingOp: *firstUserOfLoop);
2199
2200 // 4. Set insertion point before terminator op of the loop and create a new
2201 // tensor.insert_slice. In the scf.for case this is a clone of the
2202 // candidateSliceOp whereas in the scf.forall case this is created from the
2203 // operands of tensor.parallel_insert_slice.
2204 if (auto sliceOp =
2205 dyn_cast<tensor::ParallelInsertSliceOp>(Val: candidateSlices.front())) {
2206 auto newForallOp = cast<scf::ForallOp>(Val: innerMostLoop.getOperation());
2207 rewriter.setInsertionPoint(newForallOp.getTerminator());
2208 } else {
2209 rewriter.setInsertionPoint(candidateSlices.front());
2210 }
2211 // 5.a. Clone all the candidate slices as equivalent insert slice ops.
2212 SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
2213 cloneAsInsertSlices(rewriter, candidateSlices);
2214
2215 // 5.b. Clone consumer op.
2216 auto clonedConsumerOp = cast<TilingInterface>(Val: rewriter.clone(op&: *consumerOp));
2217 SmallVector<unsigned> operandNumbers =
2218 llvm::map_to_vector(C&: consumerOpOperands, F: [](OpOperand *opOperand) {
2219 return opOperand->getOperandNumber();
2220 });
2221 SmallVector<OpOperand *> clonedOpFusedOperandsList =
2222 llvm::map_to_vector(C&: operandNumbers, F: [&](unsigned operandNum) {
2223 return &clonedConsumerOp->getOpOperand(idx: operandNum);
2224 });
2225
2226 // 5.c. Replace all uses of the loop result with the result of the cloned
2227 // tensor.insert_slice.
2228 rewriter.modifyOpInPlace(root: clonedConsumerOp, callable: [&]() {
2229 for (auto [operandToReplace, clonedSliceOp] :
2230 llvm::zip_equal(t&: clonedOpFusedOperandsList, u&: clonedInsertSlices)) {
2231 operandToReplace->set(clonedSliceOp.getResult());
2232 }
2233 });
2234
2235 // 6. Perform tiling of the cloned consumer and replace the operand at
2236 // `operandNumber` with the source of the cloned tensor.insert_slice op.
2237 FailureOr<TilingResult> tileAndFuseResult =
2238 tensor::replaceInsertSlicesWithTiledConsumer(builder&: rewriter, sliceOps: clonedInsertSlices,
2239 consumerOperands: clonedOpFusedOperandsList);
2240 if (failed(Result: tileAndFuseResult)) {
2241 return failure();
2242 }
2243
2244 auto tiledConsumerOp = cast<TilingInterface>(Val: tileAndFuseResult->tiledOps[0]);
2245 for (auto [operandNum, clonedSliceOp] :
2246 llvm::zip_equal(t&: operandNumbers, u&: clonedInsertSlices)) {
2247 rewriter.replaceAllUsesWith(from: tiledConsumerOp->getOperand(idx: operandNum),
2248 to: clonedSliceOp.getSource());
2249 }
2250
2251 // 7. Reconstruct [nested] loop with new inits.
2252 YieldTiledValuesFn newYieldValuesFn =
2253 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2254 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2255 SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
2256 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2257 OpBuilder::InsertionGuard g(innerRewriter);
2258 // 8. Set inner insertPoint right before tiled consumer op.
2259 innerRewriter.setInsertionPoint(tiledConsumerOp);
2260
2261 SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
2262 for (auto candidateSliceOp : clonedInsertSlices) {
2263 SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
2264 SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
2265 SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
2266
2267 // 9. Check all insert stride is 1.
2268 if (!llvm::all_of(Range&: strides, P: isOneInteger)) {
2269 return rewriter.notifyMatchFailure(
2270 arg&: candidateSliceOp, msg: "containingOp's result yield with stride");
2271 }
2272
2273 allOffsets.emplace_back(Args: std::move(offsets));
2274 allSizes.emplace_back(Args: std::move(sizes));
2275 }
2276
2277 // 10. Try to get iter domain position from input position. Use
2278 // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2279 // domain may require index computation based on the result size. The
2280 // sizes and offsets should be the same either way, but using
2281 // tiledConsumerOp could lead to some chained unnecessary extra index
2282 // computation.
2283 SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2284 if (failed(Result: clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2285 b&: rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2286 iterDomainSizes))) {
2287 return rewriter.notifyMatchFailure(
2288 arg&: clonedConsumerOp,
2289 msg: "can't get iter domain position from input position");
2290 }
2291
2292 // 11. Try to fetch the offset and size for all results of the cloned
2293 // consumer. This would then be used to form the corresponding
2294 // tensor.insert_slice/parallel_insert_slice later.
2295 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2296 SmallVector<SmallVector<OpFoldResult>> resultOffsets(
2297 totalNumResultsOfConsumer);
2298 SmallVector<SmallVector<OpFoldResult>> resultSizes(
2299 totalNumResultsOfConsumer);
2300 for (auto [idx, v] : llvm::enumerate(First: tiledConsumerOp->getResults())) {
2301 if (failed(Result: tiledConsumerOp.getResultTilePosition(
2302 b&: rewriter, resultNumber: idx, offsets: iterDomainOffsets, sizes: iterDomainSizes,
2303 resultOffsets&: resultOffsets[idx], resultSizes&: resultSizes[idx]))) {
2304 return rewriter.notifyMatchFailure(
2305 arg&: tiledConsumerOp,
2306 msg: "can't get result domain position from iter domain position");
2307 }
2308 }
2309
2310 // 12. Create `extract_slice` for `iter_args` for DPS operation if
2311 // necessary.
2312 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2313 Val: tiledConsumerOp.getOperation())) {
2314 rewriter.setInsertionPoint(tiledDestStyleOp);
2315 for (const auto &&[index, newRegionArg] :
2316 llvm::enumerate(First&: newRegionIterArgs)) {
2317 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2318 location: loc, args&: newRegionArg, args&: resultOffsets[index], args&: resultSizes[index],
2319 args: SmallVector<OpFoldResult>(resultOffsets[index].size(),
2320 rewriter.getIndexAttr(value: 1)));
2321 // Make a copy of index to avoid a capturing structured binding, which
2322 // is a C++20 extension.
2323 auto dstNumber = index;
2324 rewriter.modifyOpInPlace(root: tiledDestStyleOp, callable: [&]() {
2325 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2326 });
2327 }
2328 }
2329
2330 // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2331 // caller.
2332 Block *block = rewriter.getInsertionPoint()->getBlock();
2333 rewriter.setInsertionPoint(block->getTerminator());
2334 for (const auto &&[index, result] :
2335 llvm::enumerate(First: tiledConsumerOp->getResults())) {
2336 tiledResult.push_back(Elt: result);
2337 tiledOffset.emplace_back(Args&: resultOffsets[index]);
2338 tiledSizes.emplace_back(Args&: resultSizes[index]);
2339 }
2340 return success();
2341 };
2342 // 14. Add new inits to [nested] loops.
2343 if (failed(Result: addInitOperandsToLoopNest(rewriter, loops, newInitValues: newInits,
2344 getNewTiledYieldsFn: newYieldValuesFn))) {
2345 return rewriter.notifyMatchFailure(arg&: tiledConsumerOp,
2346 msg: "unable to add new inits to nest loop");
2347 }
2348
2349 // 15. Replace the result of scf loop and consumer op with new loop's
2350 // results.
2351
2352 for (auto &&[oldResult, newResult] :
2353 llvm::zip(t: consumerOp->getResults(),
2354 u: loops.front()->getResults().take_back(n: newInits.size()))) {
2355 rewriter.replaceAllUsesWith(from: oldResult, to: newResult);
2356 }
2357
2358 // 16. Need to erase the old scf loop and the cloned consumer op.
2359 rewriter.eraseOp(op: clonedConsumerOp);
2360
2361 SmallVector<OpOperand *> tiledAndFusedOpOperands =
2362 llvm::map_to_vector(C&: operandNumbers, F: [&](unsigned operandNum) {
2363 return &tileAndFuseResult->tiledOps[0]->getOpOperand(idx: operandNum);
2364 });
2365 return scf::SCFFuseConsumerOfSliceResult{
2366 .origConsumerOperands: std::move(consumerOpOperands), .tiledAndFusedConsumerOperands: std::move(tiledAndFusedOpOperands),
2367 .tiledOps: std::move(tileAndFuseResult->tiledOps)};
2368}
2369
2370//===----------------------------------------------------------------------===//
2371// lowerToLoopsUsingSCFForOp implementation.
2372//===----------------------------------------------------------------------===//
2373
2374FailureOr<SmallVector<scf::ForOp>>
2375mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2376 TilingInterface op) {
2377 // TODO: Handle cases where the op has results if needed.
2378 if (op->getNumResults() > 0) {
2379 return rewriter.notifyMatchFailure(
2380 arg&: op, msg: "unable to lower to loops operations with return values");
2381 }
2382
2383 SmallVector<Range> domain = op.getIterationDomain(b&: rewriter);
2384 SmallVector<Value> ivs;
2385 SmallVector<scf::ForOp> loops;
2386 Location loc = op.getLoc();
2387 for (auto loopRange : domain) {
2388 Value offsetVal =
2389 getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.offset);
2390 Value sizeVal =
2391 getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.size);
2392 Value strideVal =
2393 getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.stride);
2394 auto loop = rewriter.create<scf::ForOp>(location: op.getLoc(), args&: offsetVal, args&: sizeVal,
2395 args&: strideVal, args: ValueRange{});
2396 loops.push_back(Elt: loop);
2397 ivs.push_back(Elt: loop.getInductionVar());
2398 rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2399 }
2400 if (failed(Result: op.generateScalarImplementation(b&: rewriter, loc: op.getLoc(), ivs))) {
2401 return failure();
2402 }
2403 return loops;
2404}
2405

source code of mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp