1//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the tiling using TilingInterface.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Arith/Utils/Utils.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/SCF/Utils/Utils.h"
22#include "mlir/Dialect/Tensor/IR/Tensor.h"
23#include "mlir/Dialect/Utils/IndexingUtils.h"
24#include "mlir/IR/Dominance.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Interfaces/DestinationStyleOpInterface.h"
28#include "mlir/Interfaces/TilingInterface.h"
29#include "mlir/Rewrite/FrozenRewritePatternSet.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/ScopeExit.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/Debug.h"
34#include <optional>
35
36#define DEBUG_TYPE "tile-using-interface"
37
38using namespace mlir;
39
40scf::SCFTilingOptions &
41scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
42 assert(!tileSizeComputationFunction && "tile sizes already set");
43 auto tileSizes = llvm::to_vector(Range&: ts);
44 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45 return tileSizes;
46 };
47 return *this;
48}
49
50scf::SCFTilingOptions &
51scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
52 assert(!numThreadsComputationFunction && "num tiles already set");
53 auto numThreads = llvm::to_vector(Range&: nt);
54 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55 return numThreads;
56 };
57 return *this;
58}
59
60/// Helper method to adjust the interchange vector to match the iteration
61/// domain.
62static SmallVector<int64_t>
63fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
64 size_t iterationDomainSize) {
65 SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector);
66 if (filledVector.size() < iterationDomainSize) {
67 auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize);
68 filledVector.append(in_start: range.begin(), in_end: range.end());
69 }
70 if (filledVector.size() > iterationDomainSize)
71 filledVector.resize(N: iterationDomainSize);
72 return filledVector;
73}
74
75//===----------------------------------------------------------------------===//
76// tileUsingSCF implementation.
77//===----------------------------------------------------------------------===//
78
79/// Verify the tile size options are set in a consistent manner.
80static LogicalResult
81verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
82 const scf::SCFTilingOptions &options) {
83 // Specifying number of threads is only supported on `scf.forall` op.
84 if (options.numThreadsComputationFunction &&
85 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
86 return rewriter.notifyMatchFailure(
87 arg&: loc, msg: "number of threads can only by specified when loop type is "
88 "set to use `scf.forall`");
89 }
90
91 // If specified, check that the interchange vector is a permutation.
92 if (!options.interchangeVector.empty()) {
93 if (!isPermutationVector(interchange: options.interchangeVector)) {
94 return rewriter.notifyMatchFailure(
95 arg&: loc, msg: "invalid interchange vector, not a permutation of the entire "
96 "iteration space");
97 }
98 }
99 return success();
100}
101
102/// Method to instantiate the tile sizes and/or number of threads specified
103/// by the user.
104static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
105getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
106 ArrayRef<Range> iterationDomain,
107 const scf::SCFTilingOptions &options) {
108 OpFoldResult zero = rewriter.getIndexAttr(0);
109 SmallVector<OpFoldResult> tileSizes, numThreads;
110 size_t numLoops = iterationDomain.size();
111
112 // Check whether the number of tiles to use is specified.
113 if (options.numThreadsComputationFunction) {
114 numThreads = options.numThreadsComputationFunction(rewriter, op);
115 numThreads.resize(N: numLoops, NV: zero);
116
117 // If the number of tiles is also specified, use that.
118 if (options.tileSizeComputationFunction) {
119 tileSizes = options.tileSizeComputationFunction(rewriter, op);
120 tileSizes.resize(N: numLoops, NV: zero);
121 return {tileSizes, numThreads};
122 }
123
124 // Compute the tile sizes from the iteration domain and number
125 // of tiles as follows
126 // - niters = ceilDiv(ub - lb, step)
127 // - tileSize = ceilDiv(niters, numThreads)
128 AffineExpr s0, s1, s2;
129 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2);
130 // TODO: The step here is assumed to be 1.
131 AffineExpr numItersExpr = (s1 - s0);
132 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(other: s2);
133 tileSizes.resize(N: numLoops, NV: zero);
134 for (auto [index, range, nt] :
135 llvm::enumerate(First&: iterationDomain, Rest&: numThreads)) {
136 if (isZeroInteger(v: nt))
137 continue;
138
139 tileSizes[index] = affine::makeComposedFoldedAffineApply(
140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141 }
142 tileSizes.resize(N: numLoops, NV: zero);
143 return {tileSizes, numThreads};
144 }
145
146 // Enforce the convention that "tiling by zero"
147 // skips tiling a particular dimension. This convention is significantly
148 // simpler to handle instead of adjusting affine maps to account for missing
149 // dimensions.
150 assert(options.tileSizeComputationFunction &&
151 "expected tile sizes to be specified");
152 tileSizes = options.tileSizeComputationFunction(rewriter, op);
153 tileSizes.resize(N: numLoops, NV: zero);
154
155 return {tileSizes, numThreads};
156}
157
158/// Checks if any of the tiled loops are not parallel.
159static void checkSafeToTileToForall(TilingInterface op,
160 ArrayRef<OpFoldResult> tileSizes,
161 ArrayRef<OpFoldResult> numThreads) {
162 auto iterators = op.getLoopIteratorTypes();
163 assert(iterators.size() == tileSizes.size() &&
164 "expected as many tile size values as number of loops");
165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166 "when specified, expected number of threads to use for each loop");
167
168 for (auto [index, iterator, tileSize] :
169 llvm::enumerate(iterators, tileSizes)) {
170 // If num threads is specified, check that it is greater than one only for
171 // parallel dimensions.
172 if (!numThreads.empty()) {
173 if (std::optional<int64_t> constNumThreads =
174 getConstantIntValue(numThreads[index])) {
175 if (constNumThreads.value() > 1 &&
176 iterator != utils::IteratorType::parallel) {
177 op.emitWarning() << "tiling is not thread safe at axis #" << index;
178 }
179 }
180 continue;
181 }
182
183 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
184 if (constTileSize.value() > 0 &&
185 iterator != utils::IteratorType::parallel) {
186 op.emitWarning() << "tiling is not thread safe at axis #" << index;
187 }
188 }
189 }
190}
191
192/// Check if `stride` evenly divides the trip count `size - offset`.
193static bool tileDividesIterationDomain(Range loopRange) {
194 std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset);
195 if (!offsetAsInt)
196 return false;
197 std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size);
198 if (!sizeAsInt)
199 return false;
200 std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride);
201 if (!strideAsInt)
202 return false;
203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204}
205
206/// Returns the bounded tile size given the current `offset`, `loopRange` and
207/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
208static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
209 Range loopRange, OpFoldResult offset,
210 OpFoldResult tileSize) {
211 std::optional<int64_t> ts = getConstantIntValue(ofr: tileSize);
212 if (ts && ts.value() == 1)
213 return tileSize;
214
215 if (tileDividesIterationDomain(
216 loopRange: Range{.offset: loopRange.offset, .size: loopRange.size, .stride: tileSize}))
217 return tileSize;
218
219 // The tile size to use (to avoid out of bounds access) is minimum of
220 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
221 // loop.
222 AffineExpr s0, s1, d0;
223 bindDims(ctx: b.getContext(), exprs&: d0);
224 bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1);
225 AffineMap minMap = AffineMap::get(dimCount: 1, symbolCount: 2, results: {s0 - d0, s1}, context: b.getContext());
226 Value size = getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange.size);
227 return affine::makeComposedFoldedAffineMin(
228 b, loc, map: minMap, operands: SmallVector<OpFoldResult>{offset, size, tileSize});
229}
230
231/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
232/// than `iterationSize`.
233static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
234 OpFoldResult numThreads,
235 OpFoldResult iterationSize) {
236 std::optional<int64_t> tileSizeConst = getConstantIntValue(ofr: tileSize);
237 std::optional<int64_t> numThreadsConst = getConstantIntValue(ofr: numThreads);
238 std::optional<int64_t> iterSizeConst = getConstantIntValue(ofr: iterationSize);
239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240 return false;
241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
242}
243
244/// Compute the `OpFoldResult`s that represents the multi-dimensional
245/// `offset`s and `size`s of the tile of the iteration space that the
246/// innermost loop body of the generated tiled loops corresponds to.
247static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
248getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
249 ArrayRef<Range> iterationDomain,
250 ArrayRef<OpFoldResult> tileSizes,
251 ArrayRef<OpFoldResult> numThreads) {
252 SmallVector<OpFoldResult> offsets, sizes;
253 int materializedLoopNum = 0;
254
255 if (!numThreads.empty()) {
256 AffineExpr d0, d1, s0, s1;
257 AffineExpr offsetExpr, residualTileSizeExpr;
258 bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1);
259 bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1);
260 offsetExpr = d0 + d1 * s0;
261 residualTileSizeExpr = s1 - (d0 + d1 * s0);
262
263 for (auto [nt, tileSize, loopRange] :
264 llvm::zip_equal(t&: numThreads, u&: tileSizes, args&: iterationDomain)) {
265
266 // Non-tiled cases, set the offset and size to the
267 // `loopRange.offset/size`.
268 if (isZeroInteger(v: nt)) {
269 offsets.push_back(Elt: loopRange.offset);
270 sizes.push_back(Elt: loopRange.size);
271 continue;
272 }
273
274 Value iv = ivs[materializedLoopNum++];
275 OpFoldResult offset = affine::makeComposedFoldedAffineApply(
276 b&: rewriter, loc, expr: offsetExpr,
277 operands: ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
278 OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
279 b&: rewriter, loc, expr: residualTileSizeExpr,
280 operands: {loopRange.offset, nt, tileSize, loopRange.size});
281
282 OpFoldResult size = tileSize;
283 if (!isZeroInteger(v: residualTileSize)) {
284 OpFoldResult sizeMinusOffsetPerThread =
285 affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: s0 - d0,
286 operands: {offset, loopRange.size});
287 size = affine::makeComposedFoldedAffineMin(
288 b&: rewriter, loc,
289 map: AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()),
290 operands: {sizeMinusOffsetPerThread, tileSize});
291 }
292
293 // Consider the case where the original loop was `[0, 100)`.
294 // If number of threads are `7`, the tile size would be computed as
295 // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
296 // - `offset = 0 + 6 * 15 = 105`
297 // - `tileSize = min(15, 100 - 105) = -5`
298 // To avoid negative tile sizes, we need to do a further
299 // `nonNegativeTileSize = affine.max(0, tileSize)`.
300 // This `max` can be avoided if
301 // `offset + tileSize * (numThreads - 1) < (ub - lb)`
302 if (!canOmitTileOffsetInBoundsCheck(tileSize, numThreads: nt, iterationSize: loopRange.size)) {
303 AffineMap maxMap =
304 AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext());
305 size = affine::makeComposedFoldedAffineMax(
306 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
307 }
308
309 offsets.push_back(Elt: offset);
310 sizes.push_back(Elt: size);
311 }
312 return {offsets, sizes};
313 } else {
314 for (auto [tileSize, loopRange] :
315 llvm::zip_equal(t&: tileSizes, u&: iterationDomain)) {
316
317 // Non-tiled cases, set the offset and size to the
318 // `loopRange.offset/size`.
319 if (isZeroInteger(v: tileSize)) {
320 offsets.push_back(Elt: loopRange.offset);
321 sizes.push_back(Elt: loopRange.size);
322 continue;
323 }
324
325 Value iv = ivs[materializedLoopNum++];
326 OpFoldResult offset = getAsOpFoldResult(val: iv);
327 offsets.push_back(Elt: offset);
328 OpFoldResult size =
329 getBoundedTileSize(b&: rewriter, loc, loopRange, offset, tileSize);
330 sizes.push_back(Elt: size);
331 }
332 return {offsets, sizes};
333 }
334}
335
336/// Function to return the bounds of the loops to be generated.
337static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
338 SmallVector<OpFoldResult>>
339getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
340 ArrayRef<OpFoldResult> tileSizes) {
341 SmallVector<OpFoldResult> lbs, ubs, steps;
342 for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) {
343 // No loop if the tile size is 0.
344 if (isZeroInteger(v: tileSize))
345 continue;
346 lbs.push_back(Elt: loopRange.offset);
347 ubs.push_back(Elt: loopRange.size);
348 steps.push_back(Elt: tileSize);
349 }
350 return {lbs, ubs, steps};
351}
352
353/// A function that allows returning additional yielded values during
354/// `yieldTiledValuesAndReplace`.
355/// - `ivs` induction variable for the loop.
356/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
357/// - `tiledValues` the tiled values to return. Must be of same size as
358/// `newbbArgs`, each element of this array is inserted into the corresponding
359/// element in `newbbArgs`.
360/// - `resultOffsets` is of the same size as `tiledValues` and represents
361/// the offsets to use when inserting corresponding element from `tiledValues`
362/// into the element from `newBbArgs`.
363/// - `resultSizes` is of the same size as `tiledValues` and represents
364/// the size of the corresponding element from `tiledValues` inserted into
365/// the element from `newBbArgs`.
366/// In case the method needs to return `failure()` the method is expected
367/// to clean up any inserted operations.
368using YieldTiledValuesFn = std::function<LogicalResult(
369 RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
370 SmallVector<Value> &tiledValues,
371 SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
372 SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
373
374/// Clones the operation and updates the destination if the operation
375/// implements the `DestinationStyleOpInterface`.
376static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
377 Operation *op,
378 ValueRange newDestArgs) {
379 Operation *clonedOp = rewriter.clone(op&: *op);
380 if (newDestArgs.empty())
381 return clonedOp;
382 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384 return clonedOp;
385}
386
387/// Generate the tile-loop nest using `scf.for` operation.
388/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
389/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
390/// - `destinationTensors` are the init values to use for the outer most loop.
391/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
392/// most
393/// loop.
394/// - `loops` is an in-out parameter into which the generated loops are
395/// populated.
396static LogicalResult generateLoopNestUsingForOp(
397 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
398 ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
399 YieldTiledValuesFn yieldTiledValuesFn,
400 SmallVector<LoopLikeOpInterface> &loops) {
401 assert(!loopRanges.empty() && "unexpected empty loop ranges");
402 assert(loopRanges.size() == tileSizes.size() &&
403 "expected as many tile sizes as loop ranges");
404 OpBuilder::InsertionGuard guard(rewriter);
405
406 SmallVector<OpFoldResult> lbs, ubs, steps;
407 std::tie(args&: lbs, args&: ubs, args&: steps) =
408 getLoopBounds(rewriter, loc, loopRanges, tileSizes);
409 SmallVector<Value> lbVals =
410 getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: lbs);
411 SmallVector<Value> ubVals =
412 getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: ubs);
413 SmallVector<Value> stepVals =
414 getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: steps);
415
416 SmallVector<Value> ivs;
417 for (auto [lb, ub, step] : llvm::zip_equal(t&: lbVals, u&: ubVals, args&: stepVals)) {
418 auto loop =
419 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
420 [](OpBuilder &bodyBuilder, Location bodyLoc,
421 Value iv, ValueRange /*iterArgs*/) {});
422 loops.push_back(loop);
423 ivs.push_back(Elt: loop.getInductionVar());
424 rewriter.setInsertionPointToEnd(loop.getBody());
425 destinationTensors = loop.getRegionIterArgs();
426 }
427
428 SmallVector<Value> tiledResults;
429 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430 if (failed(Result: yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431 tiledResults, resultOffsets, resultSizes))) {
432 return rewriter.notifyMatchFailure(
433 arg&: loc, msg: "failed to generate inner tile loop body");
434 }
435 if (loops.empty())
436 return success();
437
438 assert(tiledResults.size() == destinationTensors.size() &&
439 "Number of results of body should be equal to number of iter args");
440
441 // 6. Yield all the results of the tiled operation.
442 SmallVector<Value> yieldedValues;
443 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444 llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets,
445 args&: resultSizes)) {
446 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
447 rewriter.getIndexAttr(1));
448 auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
449 loc, tiledValue, destinationTensor, resultOffset, resultSize,
450 resultStride);
451 yieldedValues.push_back(Elt: insertSlice);
452 }
453 rewriter.create<scf::YieldOp>(loc, yieldedValues);
454
455 // Add the scf.yield operations for all the outer loops.
456 for (auto [outerLoop, innerLoop] :
457 llvm::zip_equal(MutableArrayRef(loops).drop_back(),
458 MutableArrayRef(loops).drop_front())) {
459 rewriter.setInsertionPointToEnd(
460 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
462 }
463 return success();
464}
465
466/// Generate the tile-loop nest using `scf.forall` operation.
467/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
468/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
469/// - `destinationTensors` are the init values to use for the outer most loop.
470/// - `mappingVector` is the mapping attributes to use for loop construction.
471/// Can be empty.
472/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
473/// most
474/// loop.
475/// - `loops` is an in-out parameter into which the generated loops are
476/// populated.
477static LogicalResult generateLoopNestUsingForallOp(
478 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
479 ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
480 ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
481 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
482 assert(!loopRanges.empty() && "unexpected empty loop ranges");
483 assert(loopRanges.size() == tileSizes.size() &&
484 "expected as many tile sizes as loop ranges");
485 OpBuilder::InsertionGuard guard(rewriter);
486
487 std::optional<ArrayAttr> mappingAttr;
488 if (!mappingVector.empty())
489 mappingAttr = rewriter.getArrayAttr(mappingVector);
490
491 scf::ForallOp forallOp;
492 bool useNumThreads = !numThreads.empty();
493
494 if (useNumThreads) {
495 // Prune the zero numthreads.
496 SmallVector<OpFoldResult> nonZeroNumThreads;
497 for (auto nt : numThreads) {
498 if (isZeroInteger(v: nt))
499 continue;
500 nonZeroNumThreads.push_back(Elt: nt);
501 }
502 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
503 destinationTensors, mappingAttr);
504 } else {
505 SmallVector<OpFoldResult> lbs, ubs, steps;
506 std::tie(args&: lbs, args&: ubs, args&: steps) =
507 getLoopBounds(rewriter, loc, loopRanges, tileSizes);
508 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
509 destinationTensors, mappingAttr);
510 }
511 loops.push_back(forallOp);
512
513 rewriter.setInsertionPoint(forallOp.getTerminator());
514 destinationTensors = forallOp.getRegionOutArgs();
515
516 SmallVector<Value> tiledResults;
517 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
518 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
519 destinationTensors, tiledResults, resultOffsets,
520 resultSizes)))
521 return rewriter.notifyMatchFailure(arg&: loc, msg: "failed to generate loop body");
522
523 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
524 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
525 llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets,
526 args&: resultSizes)) {
527 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
528 rewriter.getIndexAttr(1));
529
530 rewriter.create<tensor::ParallelInsertSliceOp>(
531 loc, tiledValue, destinationTensor, resultOffset, resultSize,
532 resultStride);
533 }
534 return success();
535}
536
537/// Generate the tile-loop nest using the loop construct specifed in `options`.
538/// - `options`: Tiling options specified.
539/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
540/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
541/// - `destinationTensors` are the init values to use for the outer most loop.
542/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
543/// most
544/// loop.
545/// - `loops` is an in-out parameter into which the generated loops are
546/// populated.
547static LogicalResult generateLoopNest(
548 RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
549 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
550 ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
551 YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
552 // If the tile sizes are all zero, no loops are generated. Just call the
553 // callback function to handle untiled case.
554 if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) {
555 SmallVector<Value> tiledResults;
556 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
557 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
558 tiledResults, resultOffsets, resultSizes);
559 }
560 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
561 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
562 destinationTensors, tiledBodyFn, loops);
563 }
564 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
565 return generateLoopNestUsingForallOp(
566 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
567 destinationTensors, tiledBodyFn, loops);
568 }
569 return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type");
570}
571
572static FailureOr<SmallVector<Value>>
573createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
574 ArrayRef<OpFoldResult> tileSizes,
575 const scf::SCFTilingOptions &options) {
576 SmallVector<Value> initTensors;
577 Location loc = op->getLoc();
578 switch (options.reductionStrategy) {
579 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
580 if (failed(tensor::getOrCreateDestinations(b&: rewriter, loc, op: op, result&: initTensors)))
581 return failure();
582 return initTensors;
583 case scf::SCFTilingOptions::ReductionTilingStrategy::
584 PartialReductionOuterReduction: {
585 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
586 if (!redOp) {
587 return rewriter.notifyMatchFailure(
588 op, "PartialReductionOuterReduction tiling strategy is only supported"
589 "for operations implementing PartialReductionOpInterface");
590 }
591 // Get reduction dimensions.
592 // TODO: PartialReductionOpInterface should really query TilingInterface
593 // itself and find reduction dimensions.
594 SmallVector<int> reductionDims;
595 for (auto [idx, iteratorType] :
596 llvm::enumerate(op.getLoopIteratorTypes())) {
597 if (iteratorType == utils::IteratorType::reduction)
598 reductionDims.push_back(idx);
599 }
600 return redOp.generateInitialTensorForPartialReduction(
601 rewriter, loc, tileSizes, reductionDims);
602 }
603 default:
604 return rewriter.notifyMatchFailure(op,
605 "unhandled reduction tiling strategy");
606 }
607}
608
609static FailureOr<TilingResult>
610getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
611 ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
612 ArrayRef<OpFoldResult> sizes,
613 const scf::SCFTilingOptions &options) {
614 switch (options.reductionStrategy) {
615 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
616 return op.getTiledImplementation(rewriter, offsets, sizes);
617 case scf::SCFTilingOptions::ReductionTilingStrategy::
618 PartialReductionOuterReduction: {
619 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
620 if (!redOp) {
621 return rewriter.notifyMatchFailure(
622 op, "PartialReductionOuterReduction tiling strategy is only "
623 "supported for operations "
624 "implementing PartialReductionOpInterface");
625 }
626 // Get reduction dimensions.
627 // TODO: PartialReductionOpInterface should really query TilingInterface
628 // itself and find reduction dimensions.
629 SmallVector<int> reductionDims;
630 for (auto [idx, iteratorType] :
631 llvm::enumerate(op.getLoopIteratorTypes())) {
632 if (iteratorType == utils::IteratorType::reduction)
633 reductionDims.push_back(idx);
634 }
635 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
636 offsets, sizes, reductionDims);
637 }
638 default:
639 return rewriter.notifyMatchFailure(op,
640 "unhandled reduction tiling strategy");
641 }
642}
643
644static LogicalResult
645getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
646 TilingInterface op, ArrayRef<OpFoldResult> offsets,
647 ArrayRef<OpFoldResult> sizes,
648 SmallVector<OpFoldResult> &resultOffset,
649 SmallVector<OpFoldResult> &resultSize,
650 const scf::SCFTilingOptions &options) {
651
652 switch (options.reductionStrategy) {
653 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
654 return op.getResultTilePosition(rewriter, index, offsets, sizes,
655 resultOffset, resultSize);
656 case scf::SCFTilingOptions::ReductionTilingStrategy::
657 PartialReductionOuterReduction: {
658 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
659 if (!redOp) {
660 return rewriter.notifyMatchFailure(
661 op, "PartialReductionOuterReduction tiling strategy is only supported"
662 "for operations implementing PartialReductionOpInterface");
663 }
664 // Get reduction dimensions.
665 // TODO: PartialReductionOpInterface should really query TilingInterface
666 // itself and find reduction dimensions.
667 SmallVector<int> reductionDims;
668 for (auto [idx, iteratorType] :
669 llvm::enumerate(op.getLoopIteratorTypes())) {
670 if (iteratorType == utils::IteratorType::reduction)
671 reductionDims.push_back(idx);
672 }
673 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
674 resultOffset, resultSize,
675 reductionDims);
676 }
677 default:
678 return rewriter.notifyMatchFailure(op,
679 "unhandled reduction tiling strategy");
680 }
681}
682
683static FailureOr<MergeResult>
684mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
685 ValueRange partialResults,
686 const scf::SCFTilingOptions &options) {
687 switch (options.reductionStrategy) {
688 case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
689 // No need to merge results for reduction tiling strategy.
690 return MergeResult{.mergeOps: {}, .replacements: partialResults};
691 case scf::SCFTilingOptions::ReductionTilingStrategy::
692 PartialReductionOuterReduction: {
693 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
694 if (!redOp) {
695 return rewriter.notifyMatchFailure(
696 op, "PartialReductionOuterReduction tiling strategy is only "
697 "supported for operations "
698 "implementing PartialReductionOpInterface");
699 }
700 // Get reduction dimensions.
701 // TODO: PartialReductionOpInterface should really query TilingInterface
702 // itself and find reduction dimensions.
703 SmallVector<int> reductionDims;
704 for (auto [idx, iteratorType] :
705 llvm::enumerate(op.getLoopIteratorTypes())) {
706 if (iteratorType == utils::IteratorType::reduction)
707 reductionDims.push_back(idx);
708 }
709 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
710 reductionDims);
711 }
712 default:
713 return rewriter.notifyMatchFailure(op,
714 "unhandled reduction tiling strategy");
715 }
716}
717
718/// Append the specified additional `newInitOperands` operands to the
719/// loops existing `init` operands (or similar), and replace `loopOp` with
720/// the new loop that has the additional init operands. The loop body of
721/// this loop is moved over to the new loop. `yieldTiledValuesFn`
722/// is called to get the new tiled values returned, and the offset
723/// and sizes at which the tiled value is inserted into the
724/// new region iter_args that correspond to the newly added init operands.
725template <typename LoopType>
726FailureOr<LoopLikeOpInterface>
727yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
728 ValueRange newInitOperands,
729 YieldTiledValuesFn yieldTiledValuesFn) {
730 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
731}
732
733/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
734template <>
735FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
736 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
737 YieldTiledValuesFn yieldTiledValuesFn) {
738 OpBuilder::InsertionGuard g(rewriter);
739 Location loc = loopOp.getLoc();
740 rewriter.setInsertionPoint(loopOp);
741
742 auto inits = llvm::to_vector(loopOp.getInitArgs());
743 inits.append(newInitOperands.begin(), newInitOperands.end());
744 auto newLoop = rewriter.create<scf::ForOp>(
745 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
746 inits, [](OpBuilder &, Location, Value, ValueRange) {});
747
748 // Move the loop body to the new op.
749 Block *loopBody = loopOp.getBody();
750 Block *newLoopBody = newLoop.getBody();
751 rewriter.mergeBlocks(
752 loopBody, newLoopBody,
753 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
754
755 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
756 rewriter.setInsertionPoint(yieldOp);
757
758 SmallVector<Value> tiledValues;
759 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
760 ValueRange newRegionIterArgs =
761 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
762 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
763 newRegionIterArgs, tiledValues, resultOffsets,
764 resultSizes))) {
765 rewriter.eraseOp(newLoop);
766 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
767 }
768
769 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
770 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
771 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
772 resultSizes)) {
773 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
774 rewriter.getIndexAttr(1));
775 Value insert = rewriter.create<tensor::InsertSliceOp>(
776 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
777 resultStride);
778 newYieldValues.push_back(insert);
779 }
780
781 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
782 rewriter.replaceOp(loopOp,
783 newLoop->getResults().take_front(loopOp.getNumResults()));
784 return cast<LoopLikeOpInterface>(newLoop.getOperation());
785}
786
787/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
788template <>
789FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
790 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
791 YieldTiledValuesFn yieldTiledValuesFn) {
792 OpBuilder::InsertionGuard g(rewriter);
793 Location loc = loopOp.getLoc();
794 rewriter.setInsertionPoint(loopOp);
795 auto inits = llvm::to_vector(loopOp.getOutputs());
796 inits.append(newInitOperands.begin(), newInitOperands.end());
797 auto newLoop = rewriter.create<scf::ForallOp>(
798 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
799 loopOp.getMixedStep(), inits, loopOp.getMapping(),
800 [](OpBuilder &, Location, ValueRange) {});
801
802 // Move the region of the current block to the newly created op.
803 Block *loopBody = loopOp.getBody();
804 Block *newLoopBody = newLoop.getBody();
805 rewriter.mergeBlocks(
806 loopBody, newLoopBody,
807 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
808
809 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
810 rewriter.setInsertionPoint(terminator);
811 SmallVector<Value> tiledValues;
812 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
813 ValueRange regionIterArgs =
814 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
815 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
816 regionIterArgs, tiledValues, resultOffsets,
817 resultSizes))) {
818 rewriter.eraseOp(newLoop);
819 return rewriter.notifyMatchFailure(loopOp,
820 "failed to get yielded tiled values");
821 }
822
823 // Update the terminator.
824 rewriter.setInsertionPointToEnd(terminator.getBody());
825
826 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
827 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
828 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
829 rewriter.getIndexAttr(1));
830 rewriter.create<tensor::ParallelInsertSliceOp>(
831 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
832 resultStride);
833 }
834
835 rewriter.replaceOp(loopOp,
836 newLoop->getResults().take_front(loopOp.getNumResults()));
837 return cast<LoopLikeOpInterface>(newLoop.getOperation());
838}
839
840/// Implementation of `yieldTiledValuesAndReplaceLoop` for
841/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
842/// supported loop type.
843FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
844 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
845 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
846 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
847 loopLikeOp.getOperation())
848 .Case<scf::ForOp, scf::ForallOp>(
849 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
850 return yieldTiledValuesAndReplaceLoop(
851 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
852 })
853 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
854 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
855 });
856}
857
858/// Method to add new init values to a loop nest. Updates `loops` in-place
859/// with new loops that use the `newInitValues`. The outer-loops are updated
860/// to yield the new result values of the inner loop. For the innermost loop,
861/// the call back `getNewYields` is invoked to get the additional values to
862/// yield form the innermost loop.
863static LogicalResult addInitOperandsToLoopNest(
864 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
865 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
866 if (loops.empty())
867 return success();
868 OpBuilder::InsertionGuard g(rewriter);
869 rewriter.setInsertionPoint(loops.front());
870
871 SmallVector<Value> ivs;
872 for (auto &loop : loops.drop_back()) {
873 rewriter.setInsertionPoint(loop);
874
875 // if loops.size() > 1 we assume that scf.for is used for the loops.
876 auto forLoop = cast<scf::ForOp>(loop.getOperation());
877
878 // Create a new loop with the new init values for this loop.
879 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
880 newInits.append(newInitValues.begin(), newInitValues.end());
881 auto newLoop = rewriter.create<scf::ForOp>(
882 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
883 forLoop.getStep(), newInits,
884 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
885
886 // Merge the body of the new loop with the body of the old loops.
887 SmallVector<Value> sourceBlockArgs;
888 sourceBlockArgs.push_back(newLoop.getInductionVar());
889 auto newRegionIterArgs = newLoop.getRegionIterArgs();
890 sourceBlockArgs.append(
891 newRegionIterArgs.begin(),
892 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
893 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
894 rewriter.replaceOp(
895 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
896 loop = newLoop;
897 ivs.push_back(newLoop.getInductionVar());
898 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
899 }
900
901 // Update the loop body of the innermost loop to get new yield values.
902 LoopLikeOpInterface innerMostLoop = loops.back();
903 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
904 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
905 getNewTiledYieldsFn);
906
907 if (failed(newInnerMostLoop))
908 return innerMostLoop.emitOpError("failed to return additional yields");
909 loops.back() = newInnerMostLoop.value();
910
911 // Make all other loops except the innermost loops yield the values returned
912 // by the inner loop.
913 for (auto [outerLoop, innerLoop] :
914 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
915 // Again assume that all the outer loops are scf.for operations.
916 auto outerForLoop = cast<scf::ForOp>(outerLoop);
917 auto outerLoopYield =
918 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
919 SmallVector<Value> newYields =
920 llvm::to_vector(outerLoopYield.getOperands());
921 ValueRange additionalYields =
922 innerLoop->getResults().take_back(newInitValues.size());
923 newYields.append(additionalYields.begin(), additionalYields.end());
924 rewriter.setInsertionPoint(outerLoopYield);
925 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
926 }
927 return success();
928}
929
930/// Implementation of tiling transformation of `op` that implements the
931/// `TilingInterface` using `scf.for` to iterate over the tiles.
932FailureOr<scf::SCFTilingResult>
933mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
934 const scf::SCFTilingOptions &options) {
935 if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
936 return failure();
937 }
938
939 OpBuilder::InsertionGuard guard(rewriter);
940 rewriter.setInsertionPointAfter(op);
941
942 // 1. Get the range of the loops that are represented by the operation.
943 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
944
945 // 2. Materialize the tile sizes and/or number of threads;
946 SmallVector<OpFoldResult> tileSizes, numThreads;
947 std::tie(args&: tileSizes, args&: numThreads) =
948 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
949
950 // Check if it is safe to tile. This is hold over from previous iterations
951 // of tile to for-all. Consider dropping it.
952 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
953 checkSafeToTileToForall(op, tileSizes, numThreads);
954 }
955
956 // 3. If there is an interchange specified, permute the iteration domain and
957 // the tile sizes.
958 SmallVector<int64_t> interchangeVector;
959 if (!options.interchangeVector.empty()) {
960 interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector,
961 iterationDomainSize: iterationDomain.size());
962 assert(isPermutationVector(interchangeVector) &&
963 "expected interchange vector to be a permutation");
964
965 applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector);
966 applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector);
967 if (!numThreads.empty())
968 applyPermutationToVector(inVec&: numThreads, permutation: interchangeVector);
969 }
970
971 FailureOr<TilingResult> tilingResult;
972 // 4. Define the lambda function used later to generate the body of the
973 // innermost tiled loop.
974 YieldTiledValuesFn innerYieldTiledValuesFn =
975 [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
976 ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
977 SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
978 SmallVector<SmallVector<OpFoldResult>> &resultSizes)
979 -> LogicalResult {
980 // 4a. Compute the `offsets` and `sizes` to use for tiling.
981 SmallVector<OpFoldResult> offsets, sizes;
982 std::tie(args&: offsets, args&: sizes) = getTileOffsetAndSizes(
983 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
984
985 // 4b. If interchange was provided, apply inverse of the interchange
986 // to get back the offsets/sizes in the order to be specified.
987 if (!interchangeVector.empty()) {
988 auto inversePermutation = invertPermutationVector(permutation: interchangeVector);
989 applyPermutationToVector(inVec&: offsets, permutation: inversePermutation);
990 applyPermutationToVector(inVec&: sizes, permutation: inversePermutation);
991 }
992
993 // 5. Generate the tiled implementation within the inner most loop.
994
995 // 5a. Clone the operation within the loop body.
996 auto clonedOp = cast<TilingInterface>(
997 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
998
999 // 5b. Early return cloned op if tiling is not happening. We can not
1000 // return the original op because it could lead to `rewriter.replaceOp(op,
1001 // op->getResults())` and users would get crash.
1002 if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) {
1003 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1004 tilingResult =
1005 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1006 /*generatedSlices=*/{}};
1007 return success();
1008 }
1009
1010 // 5c. Tile the cloned operation.
1011 tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
1012 offsets, sizes, options);
1013 if (failed(Result: tilingResult)) {
1014 rewriter.eraseOp(op: clonedOp);
1015 return op.emitOpError("faild to tile operation");
1016 }
1017
1018 // 5d. Delete the cloned operation.
1019 rewriter.eraseOp(op: clonedOp);
1020
1021 // 5e. Compute the offsets at which the result values are to be inserted
1022 // back into its destinations.
1023 for (auto [index, tiledValue] :
1024 llvm::enumerate(First&: tilingResult->tiledValues)) {
1025 tiledResults.push_back(Elt: tiledValue);
1026 SmallVector<OpFoldResult> resultOffset, resultSize;
1027 if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
1028 sizes, resultOffset, resultSize,
1029 options))) {
1030 for (auto op : tilingResult->tiledOps) {
1031 rewriter.eraseOp(op);
1032 }
1033 return rewriter.notifyMatchFailure(
1034 op, "failed to get slice of result produced");
1035 }
1036 resultOffsets.emplace_back(Args: std::move(resultOffset));
1037 resultSizes.emplace_back(Args: std::move(resultSize));
1038 }
1039
1040 return success();
1041 };
1042
1043 // 6. Find the destination tensors to use for the operation.
1044 FailureOr<SmallVector<Value>> maybeInits =
1045 createInitialTensorsForTiling(rewriter, op, tileSizes, options);
1046 if (failed(Result: maybeInits)) {
1047 return rewriter.notifyMatchFailure(
1048 op, "unable to create initial tensors for tiling");
1049 }
1050 SmallVector<Value> &initTensors = maybeInits.value();
1051
1052 // 7. Generate the tiled loops nest using the callback defined above.
1053 SmallVector<LoopLikeOpInterface> loops;
1054 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
1055 tileSizes, numThreads, initTensors,
1056 innerYieldTiledValuesFn, loops)))
1057 return op.emitOpError("failed to generate tiling loops");
1058 assert(succeeded(tilingResult) &&
1059 "expected tiling result to be computed after loop generation");
1060
1061 if (loops.empty()) {
1062 // If loops are empty, the tiled op is used as the replacement for the
1063 // untiled op.
1064 return scf::SCFTilingResult{tilingResult->tiledOps,
1065 initTensors,
1066 loops,
1067 tilingResult->tiledValues,
1068 tilingResult->generatedSlices,
1069 {}};
1070 }
1071
1072 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1073 [](OpResult r) -> Value { return r; });
1074
1075 // For the full reduction case, there is nothing more to do.
1076 if (options.reductionStrategy ==
1077 scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) {
1078 return scf::SCFTilingResult{
1079 tilingResult->tiledOps, initTensors, loops, loopResults,
1080 tilingResult->generatedSlices, {}};
1081 }
1082
1083 // The results of the loop needs to be merged.
1084 FailureOr<MergeResult> mergeResult =
1085 mergeTilingResults(rewriter, op, loopResults, options);
1086 if (failed(Result: mergeResult)) {
1087 return rewriter.notifyMatchFailure(
1088 op, "Failed to merge partial results from tiling");
1089 }
1090 return scf::SCFTilingResult{tilingResult->tiledOps,
1091 initTensors,
1092 loops,
1093 mergeResult->replacements,
1094 tilingResult->generatedSlices,
1095 mergeResult->mergeOps};
1096}
1097
1098FailureOr<scf::SCFTilingResult>
1099mlir::scf::tileReductionUsingScf(RewriterBase &b,
1100 PartialReductionOpInterface op,
1101 ArrayRef<OpFoldResult> tileSize) {
1102 scf::SCFTilingOptions options;
1103 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1104 options.setReductionTilingStrategy(
1105 scf::SCFTilingOptions::ReductionTilingStrategy::
1106 PartialReductionOuterReduction);
1107 options.setTileSizes(tileSize);
1108 return tileUsingSCF(b, op, options);
1109}
1110
1111//===----------------------------------------------------------------------===//
1112// tileConsumerAndFuseProducersUsingSCF implementation.
1113//===----------------------------------------------------------------------===//
1114
1115/// Return the untiled producer whose slice is used in a tiled consumer. The
1116/// method traverses the tile loop nest (`loops`) if needed, and returns the
1117/// `iter_args` of the outer most that is encountered. Traversing the
1118/// iter_args indicates that this is a destination operand of the consumer. If
1119/// there was no loop traversal needed, the second value of the returned tuple
1120/// is empty.
1121static std::tuple<OpResult, std::optional<OpOperand *>>
1122getUntiledProducerFromSliceSource(OpOperand *source,
1123 ArrayRef<LoopLikeOpInterface> loops) {
1124 std::optional<OpOperand *> destinationIterArg;
1125 assert(!loops.empty() && "expected non empty loops container");
1126 auto loopIt = loops.rbegin();
1127 while (loopIt != loops.rend() && isa<BlockArgument>(Val: source->get())) {
1128 auto iterArg = cast<BlockArgument>(Val: source->get());
1129 auto loop = *loopIt;
1130 if (iterArg.getOwner()->getParentOp() != loop)
1131 break;
1132 source = loop.getTiedLoopInit(iterArg);
1133 loopIt++;
1134 }
1135 if (loopIt == loops.rend())
1136 destinationIterArg = source;
1137 return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg};
1138}
1139
1140/// Implementation of fusing producer of a single slice by computing the
1141/// slice of the producer in-place.
1142std::optional<scf::SCFFuseProducerOfSliceResult>
1143mlir::scf::tileAndFuseProducerOfSlice(
1144 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1145 MutableArrayRef<LoopLikeOpInterface> loops) {
1146 // 1. Get the producer of the source (potentially walking through
1147 // `iter_args` of nested `scf.for`)
1148 auto [fusableProducer, destinationInitArg] =
1149 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1150 loops);
1151 if (!fusableProducer)
1152 return std::nullopt;
1153 unsigned resultNumber = fusableProducer.getResultNumber();
1154
1155 OpBuilder::InsertionGuard g(rewriter);
1156 rewriter.setInsertionPoint(candidateSliceOp);
1157
1158 // 2. Clone the fused producer
1159 // 2a. Compute the destination operands to use for the cloned operation.
1160 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1161 Operation *fusableProducerOp = fusableProducer.getOwner();
1162 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1163 failed(tensor::getOrCreateDestinations(
1164 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1165 origDestinationTensors)))
1166 return std::nullopt;
1167
1168 clonedOpDestinationTensors = origDestinationTensors;
1169 if (destinationInitArg &&
1170 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1171 // 2b. If the producer is also destination style, then to maintain the
1172 // destination passing style, update the destination of the producer to be
1173 // the source of the slice.
1174 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1175 }
1176 // 2c. Clone the fused producer.
1177 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1178 rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors);
1179 // 2d. Update the source of the candidateSlice to be the cloned producer.
1180 // Easier to just clone the slice with different source since
1181 // replacements and DCE of cloned ops becomes easier
1182 SmallVector<Value> candidateSliceOpOperands =
1183 llvm::to_vector(candidateSliceOp->getOperands());
1184 candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber);
1185 tensor::ExtractSliceOp clonedCandidateSliceOp =
1186 mlir::clone(rewriter, candidateSliceOp,
1187 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1188
1189 // 3. Generate the tiled implementation of the producer of the source
1190 FailureOr<TilingResult> tileAndFuseResult =
1191 tensor::replaceExtractSliceWithTiledProducer(
1192 rewriter, clonedCandidateSliceOp,
1193 clonedProducerOp->getResult(idx: resultNumber));
1194 if (failed(Result: tileAndFuseResult))
1195 return std::nullopt;
1196 // Note: Do not delete the candidateSliceOp, since its passed in from the
1197 // caller.
1198 rewriter.replaceAllUsesWith(candidateSliceOp,
1199 tileAndFuseResult->tiledValues[0]);
1200 rewriter.eraseOp(op: clonedCandidateSliceOp);
1201 rewriter.eraseOp(op: clonedProducerOp);
1202
1203 // 3. If the slice is for a destination operand, for example,
1204 //
1205 // ```mlir
1206 // %0 = linalg.init
1207 // %1 = linalg.fill .. outs(%0 : )
1208 // %2 = scf.for .. iter_args(%arg0 = %1) {
1209 // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1210 // %4 = tensor.extract_slice %arg1 [..]
1211 // .. = linalg.matmul .. outs(%4 : )
1212 // }
1213 // }
1214 // ```
1215 //
1216 // the IR is currently
1217 //
1218 // ```
1219 // %0 = linalg.init
1220 // %1 = linalg.fill
1221 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1222 // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1223 // %4 = tensor.extract_slice %arg1[..]
1224 // %5 = linalg.fill .. outs(%4 : )
1225 // .. = linalg.matmul .. outs(%5 : )
1226 // }
1227 // }
1228 // ```
1229 //
1230 // The untiled `linalg.fill` is still used as the `init_value` since it
1231 // was originally a destination operand of the untiled `linalg.matmul`.
1232 // When fusing an operand that is a destination operand, the iter_arg of
1233 // the outer most loop should be changed to use the destination of the
1234 // fused operation. With this the IR will be.
1235 //
1236 // ```
1237 // %0 = linalg.init
1238 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1239 // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1240 // %3 = tensor.extract_slice %arg1[..]
1241 // %4 = linalg.fill .. outs(%3 : )
1242 // .. = linalg.matmul .. outs(%4 : )
1243 // }
1244 // }
1245 // ```
1246 if (destinationInitArg &&
1247 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1248 loops.front()
1249 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1250 .set(origDestinationTensors[resultNumber]);
1251 }
1252 return scf::SCFFuseProducerOfSliceResult{
1253 fusableProducer, tileAndFuseResult->tiledValues[0],
1254 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1255}
1256
1257/// Reconstruct the fused producer from within the tiled-and-fused code.
1258FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1259 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1260 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1261 MutableArrayRef<LoopLikeOpInterface> loops,
1262 ArrayRef<unsigned> yieldResultNumber) {
1263 if (loops.empty())
1264 return success();
1265
1266 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1267 *tiledOwner = fusedProducerInfo.tiledOps[0];
1268
1269 Location loc = originalOwner->getLoc();
1270 // a. collect all init Value to be appended
1271 SmallVector<unsigned> initNumberList =
1272 yieldResultNumber.empty() ? llvm::to_vector(Range: llvm::seq<unsigned>(
1273 Begin: 0, End: originalOwner->getNumResults()))
1274 : llvm::to_vector(Range&: yieldResultNumber);
1275 SmallVector<Value> initValueList;
1276 for (const auto &resultNumber : initNumberList) {
1277 FailureOr<Value> initValue = tensor::getOrCreateDestination(
1278 b&: rewriter, loc, opResult: originalOwner->getResult(idx: resultNumber));
1279 if (succeeded(Result: initValue)) {
1280 initValueList.push_back(Elt: initValue.value());
1281 } else {
1282 return failure();
1283 }
1284 }
1285
1286 SmallVector<Operation *> generatedSlices;
1287 YieldTiledValuesFn newYieldValuesFn =
1288 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1289 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1290 SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
1291 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1292 OpBuilder::InsertionGuard g(innerRewriter);
1293
1294 // get sliceOp tile information
1295 SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1296 sliceSizes = sliceOp.getMixedSizes();
1297
1298 // expect all strides of sliceOp being 1
1299 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1300 return failure();
1301
1302 unsigned sliceResultNumber =
1303 fusedProducerInfo.origProducer.getResultNumber();
1304
1305 auto tilableOp = cast<TilingInterface>(originalOwner);
1306 // b. Get the iteration domain offsets and sizes based on the sliceOp tile.
1307 SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1308 // Skip tensor.pack/unpack/pad, which expect a single opResult.
1309 if (tilableOp->getNumResults() > 1 &&
1310 failed(tilableOp.getIterationDomainTileFromResultTile(
1311 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1312 iterDomainOffset, iterDomainSizes))) {
1313 // In theory it is unnecessary to raise an error here: even if the
1314 // result tensor cannot be reconstructed, that should not break the
1315 // current fusion. The reason we must return failure for now is that
1316 // the callback `newYieldValuesFn` is called after the new init
1317 // operand(s) have already been appended, and more refactoring is
1318 // needed to make sure the init operands are added consistently.
1319 // For more details, please refer to:
1320 // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1321 return failure();
1322 }
1323
1324 // c. Calculate the offsets and sizes of all OpResults based on the
1325 // iteration domain tile.
1326 SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1327 for (const auto &resultNumber : initNumberList) {
1328 if (resultNumber == sliceResultNumber) {
1329 offsetList.push_back(sliceOffset);
1330 sizesList.push_back(sliceSizes);
1331 } else {
1332 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1333 // Infer the result tile from the iteration domain tile.
1334 SmallVector<OpFoldResult> offset, sizes;
1335 if (failed(tilableOp.getResultTilePosition(
1336 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1337 offset, sizes))) {
1338 return failure();
1339 }
1340 offsetList.push_back(offset);
1341 sizesList.push_back(sizes);
1342 }
1343 }
1344
1345 // d. Create `extract_slice` for `iter_args` for DPS operation if
1346 // necessary.
1347 if (auto tiledDestStyleOp =
1348 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1349 rewriter.setInsertionPoint(tiledDestStyleOp);
1350 for (const auto &&[index, newRegionArg] :
1351 llvm::enumerate(newRegionIterArgs)) {
1352 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1353 loc, newRegionArg, offsetList[index], sizesList[index],
1354 SmallVector<OpFoldResult>(offsetList[index].size(),
1355 rewriter.getIndexAttr(1)));
1356 generatedSlices.push_back(destSlice);
1357 unsigned resultNumber = initNumberList[index];
1358 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1359 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1360 });
1361 }
1362 }
1363
1364 // e. Prepare the tiled offsets and sizes for later `insert_slice` creation
1365 // by the caller.
1366 Block *block = rewriter.getInsertionPoint()->getBlock();
1367 rewriter.setInsertionPoint(block->getTerminator());
1368 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1369 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1370 tiledOffset.emplace_back(offsetList[index]);
1371 tiledSizes.emplace_back(sizesList[index]);
1372 }
1373 return success();
1374 };
1375
1376 if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1377 newYieldValuesFn))) {
1378 return failure();
1379 }
1380 return generatedSlices;
1381}
1382
1383namespace {
1384
1385//===----------------------------------------------------------------------===//
1386// SliceTrackingListener
1387//===----------------------------------------------------------------------===//
1388
1389/// This class is a listener for tracking the insertion and removal of
1390/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1391/// fusion algorithm to apply cleanup patterns in between fusion steps.
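/// A minimal usage sketch (illustrative only; `cleanupPatterns` and
/// `initialSlices` are placeholder values, not part of this file):
///
/// ```
/// SliceTrackingListener tracker(cleanupPatterns);
/// if (failed(tracker.insertAndApplyPatterns(initialSlices)))
///   return failure();
/// while (!tracker.worklist.empty()) {
///   tensor::ExtractSliceOp slice = tracker.worklist.front();
///   tracker.worklist.pop_front();
///   // ... attempt fusion through `slice` and feed any newly generated
///   // slices back via insertAndApplyPatterns().
/// }
/// ```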
1392class SliceTrackingListener : public RewriterBase::Listener {
1393public:
1394 explicit SliceTrackingListener(
1395 std::optional<FrozenRewritePatternSet> patterns);
1396 SliceTrackingListener() = default;
1397
1398 /// Adds the given list of operations to the worklist, and if present,
1399 /// applies the list of `patterns` to the newly added operations. This only
1400 /// processes the given operations and any newly inserted ones by the
1401 /// pattern set.
1402 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1403
1404 /// Add to the new operation worklist if it is an extract_slice.
1405 void notifyOperationInserted(Operation *op,
1406 OpBuilder::InsertPoint previous) override;
1407
1408 /// Shared helper for operation removal from the worklist.
1409 void removeOp(Operation *op);
1410
1411 /// Remove the operation from the worklist.
1412 void notifyOperationErased(Operation *op) override;
1413
1414 /// Remove the operation from the worklist.
1415 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1416
1417 /// The worklist for this transformation keeps track of the slices to visit
1418 /// next for fusion.
1419 std::deque<tensor::ExtractSliceOp> worklist;
1420
1421private:
1422 /// Optional pattern set to apply when adding new operations to the
1423 /// worklist.
1424 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1425};
1426
1427SliceTrackingListener::SliceTrackingListener(
1428 std::optional<FrozenRewritePatternSet> p) {
1429 patterns = std::move(p);
1430}
1431
1432LogicalResult
1433SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1434 for (Operation *op : ops) {
1435 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1436 worklist.push_back(slice);
1437 }
1438
1439 if (!patterns)
1440 return success();
1441
1442 return applyOpPatternsGreedily(
1443 ops, patterns.value(),
1444 GreedyRewriteConfig().setListener(this).setStrictness(
1445 GreedyRewriteStrictness::ExistingAndNewOps));
1446}
1447
1448void SliceTrackingListener::notifyOperationInserted(
1449 Operation *op, OpBuilder::InsertPoint previous) {
1450 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1451 if (!slice)
1452 return;
1453 worklist.push_back(slice);
1454}
1455
1456// Scan the worklist for the given op and remove it if present. The
1457// expectation is for the worklist to be small and for removal to be
1458// relatively rare.
1459void SliceTrackingListener::removeOp(Operation *op) {
1460 if (!isa<tensor::ExtractSliceOp>(op))
1461 return;
1462 auto iter = worklist.begin();
1463 while (iter != worklist.end()) {
1464 if (*iter == op)
1465 break;
1466 iter++;
1467 }
1468 if (iter == worklist.end())
1469 return;
1470
1471 worklist.erase(iter);
1472}
1473
1474void SliceTrackingListener::notifyOperationErased(Operation *op) {
1475 removeOp(op);
1476}
1477
1478void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1479 ValueRange replacement) {
1480 removeOp(op);
1481}
1482
1483//===----------------------------------------------------------------------===//
1484// ReplacementListener
1485//===----------------------------------------------------------------------===//
1486
1487/// Listener that tracks updates to the replacements for values, which can
1488/// be mutated. This listener runs on top of the existing listener for the
1489/// rewriter, to make sure external users can still run listeners.
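/// Illustrative example (hypothetical values): if the fusion driver has
/// recorded `replacements[%matmulResult] = %loop#0` and a later step replaces
/// `%loop` with `%newLoop` (e.g. after appending iter_args), this listener
/// rewrites the entry to `%newLoop#0` so the replacements handed back to the
/// caller remain valid.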
1490class ReplacementListener : public RewriterBase::ForwardingListener {
1491public:
1492 ReplacementListener(DenseMap<Value, Value> &replacements,
1493 OpBuilder::Listener *listener)
1494 : ForwardingListener(listener), replacements(replacements) {}
1495
1496 void updateReplacementValues(ValueRange origValues,
1497 ValueRange replaceValues) {
1498 // This can probably be written better, but just iterates over the map
1499 // and the new replacements for now.
1500 for (auto &[key, val] : replacements) {
1501 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1502 if (val == orig) {
1503 val = replace;
1504 }
1505 }
1506 }
1507 }
1508
1509 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1510 ForwardingListener::notifyOperationReplaced(op, newOp);
1511 updateReplacementValues(op->getResults(), newOp->getResults());
1512 }
1513
1514 void notifyOperationReplaced(Operation *op, ValueRange values) override {
1515 ForwardingListener::notifyOperationReplaced(op, values);
1516 updateReplacementValues(op->getResults(), values);
1517 }
1518
1519private:
1520 DenseMap<Value, Value> &replacements;
1521};
1522
1523} // namespace
1524
1525/// Implementation of tile consumer and fuse producer greedily.
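/// As an illustrative sketch only (schematic IR, not verified): starting from
///
/// ```
/// %0 = linalg.fill .. outs(%init : )
/// %1 = linalg.matmul .. outs(%0 : )
/// ```
///
/// tiling the matmul consumer and greedily fusing the fill producer yields
/// roughly
///
/// ```
/// %1 = scf.for .. iter_args(%arg0 = %init) {
///   %2 = tensor.extract_slice %arg0[..]
///   %3 = linalg.fill .. outs(%2 : )
///   %4 = linalg.matmul .. outs(%3 : )
///   .. = tensor.insert_slice %4 into %arg0[..]
/// }
/// ```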
1526FailureOr<scf::SCFTileAndFuseResult>
1527mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1528 RewriterBase &rewriter, TilingInterface consumer,
1529 const scf::SCFTileAndFuseOptions &options) {
1530 // This transformation is only valid for ops that return values (i.e. not
1531 // valid to use with operations that have memref operands).
1532 if (!consumer->getNumResults()) {
1533 return rewriter.notifyMatchFailure(
1534 consumer, "invalid pattern for op with no results");
1535 }
1536
1537 // 1. First tile the consumer.
1538 SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1539
1540 FailureOr<scf::SCFTilingResult> tilingResult =
1541 tileUsingSCF(rewriter, consumer, options.tilingOptions);
1542
1543 if (failed(tilingResult))
1544 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1545 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1546
1547 DenseMap<Value, Value> replacements;
1548 for (auto [origVal, replacement] :
1549 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1550 replacements[origVal] = replacement;
1551 }
1552
1553 // If there are no loops generated, fusion is immaterial.
1554 auto &loops = tilingResult->loops;
1555 if (loops.empty()) {
1556 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1557 replacements};
1558 }
1559
1560 // Since the loop gets potentially replaced during fusion, we need to track
1561 // the mutation of replacement values. To do this, we attach a listener to
1562 // update the replacements as they happen.
1563 OpBuilder::Listener *previousListener = rewriter.getListener();
1564 auto resetListener =
1565 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1566 ReplacementListener replaceListener(replacements, previousListener);
1567 rewriter.setListener(&replaceListener);
1568
1569 // 2. Typically, the operands of the tiled operation are slices of the
1570 // operands of the untiled operation. These are expressed in IR using
1571 // `tensor.extract_slice` operations with source being the operands of
1572 // the untiled operation. Create a worklist of these
1573 // `tensor.extract_slice` operations. If the producers of the source of
1574 // the `tensor.extract_slice` can be tiled such that the tiled value is
1575 // generated in-place, that effectively tiles + fuses the operations.
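  // For example (illustrative), right after tiling a matmul whose operands
  // are produced by untiled operations, the loop body contains slices such as
  //
  //   %lhs_tile = tensor.extract_slice %untiled_lhs[..]
  //
  // Each such slice seeds the worklist below; if the producer of
  // `%untiled_lhs` implements the `TilingInterface`, it can be tiled to
  // produce `%lhs_tile` in place, fusing it into the loop nest.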
1576 struct WorklistItem {
1577 tensor::ExtractSliceOp candidateSlice;
1578 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1579 };
1580
1581 SliceTrackingListener sliceTracker =
1582 SliceTrackingListener(options.cleanupPatterns);
1583
1584 if (failed(
1585 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1586 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1587 }
1588 OpBuilder::InsertionGuard g(rewriter);
1589 while (!sliceTracker.worklist.empty()) {
1590 auto candidateSlice = sliceTracker.worklist.front();
1591 sliceTracker.worklist.pop_front();
1592
1593 auto [fusableProducer, destinationInitArg] =
1594 getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1595 loops);
1596 if (!fusableProducer)
1597 continue;
1598
1599 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1600 options.fusionControlFn(candidateSlice, fusableProducer,
1601 destinationInitArg.has_value());
1602 if (!controlFnResult)
1603 continue;
1604
1605 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1606
1607 // The operands of the fused producer might themselves be slices of
1608 // values produced by operations that implement the `TilingInterface`.
1609 // Add these operations to the worklist.
1610 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1611 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1612 loops);
1613 if (!fusedResult)
1614 continue;
1615
1616 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1617
1618 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1619 // Reconstruct and yield all opResults of fusableProducerOp by default.
1620 // The caller can specify which ones to yield via the optional
1621 // `yieldResultNumber` argument of
1622 // `yieldReplacementForFusedProducer`.
1623 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1624 FailureOr<SmallVector<Operation *>> newSlices =
1625 yieldReplacementForFusedProducer(rewriter,
1626 worklistItem.candidateSlice,
1627 fusedResult.value(), loops);
1628 if (failed(newSlices)) {
1629 return rewriter.notifyMatchFailure(
1630 fusableProducerOp, "failed to yield a replacement value for this "
1631 "operation from within the tiled loop");
1632 }
1633 worklistCandidates.append(newSlices.value());
1634 for (auto [index, result] :
1635 llvm::enumerate(fusableProducerOp->getResults())) {
1636 replacements[result] = loops.front()->getResult(
1637 loops.front()->getNumResults() -
1638 fusableProducerOp->getNumResults() + index);
1639 }
1640 }
1641 if (Operation *tiledAndFusedOp =
1642 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1643 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1644 tiledAndFusedOps.insert(tiledAndFusedOp);
1645 }
1646
1647 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1648 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1649 }
1650 }
1651
1652 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1653 replacements};
1654}
1655
1656//===----------------------------------------------------------------------===//
1657// tileAndFuseConsumerUsingSCF implementation.
1658//===----------------------------------------------------------------------===//
1659
1660/// A utility function that checks whether the only use of the result of a
1661/// tensor.insert_slice op is in a scf.yield op.
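/// Illustrative sketch of the accepted pattern (schematic IR):
///
/// ```
/// %0 = scf.for .. iter_args(%arg0 = ..) {
///   ..
///   %1 = tensor.insert_slice %tiled into %arg0[..]
///   scf.yield %1
/// }
/// ```
///
/// i.e. the tensor.insert_slice result is used only by the scf.yield in the
/// same block.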
1662static LogicalResult
1663checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1664 Value result = candidateSliceOp.getResult();
1665 Value::use_range uses = result.getUses();
1666 if (!llvm::hasSingleElement(uses)) {
1667 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1668 return failure();
1669 }
1670 OpOperand &operandUse = (*uses.begin());
1671 Operation *userOp = operandUse.getOwner();
1672 if (!isa<scf::YieldOp>(userOp)) {
1673 LLVM_DEBUG(llvm::dbgs()
1674 << "Expected scf.yield to be the only user, but got -> "
1675 << (*userOp));
1676 return failure();
1677 }
1678 if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1679 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1680 "be in the same block\n");
1681 return failure();
1682 }
1683 return success();
1684}
1685
1686/// A utility to get the first user of the given loopOp. If any user is in a
1687/// different block than loopOp, return failure.
1688static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1689 if (!isa<LoopLikeOpInterface>(loopOp))
1690 return failure();
1691 Operation *firstUserOfLoop = nullptr;
1692 for (Operation *userOp : loopOp->getUsers()) {
1693 // A `ParallelInsertSlice` located inside an `InParallelOp` does not share
1694 // a parent block with any other type of operation. Thus, just redirect to
1695 // its parent `InParallelOp`. E.g.
1696 //
1697 // ```
1698 // %1 = scf.for {
1699 // ...
1700 // }
1701 // %2 = consumerOp ins(%1, ...)
1702 // scf.forall.in_parallel {
1703 // tensor.parallel_insert_slice %1
1704 // }
1705 // ```
1706 // where the `InParallelOp`, but not the `ParallelInsertSlice`, is in the
1707 // same block as `consumerOp`.
1708 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1709 userOp = userOp->getParentOfType<scf::InParallelOp>();
1710
1711 if (loopOp->getBlock() != userOp->getBlock())
1712 return failure();
1713
1714 if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1715 firstUserOfLoop = userOp;
1716 }
1717 return firstUserOfLoop;
1718}
1719
1720/// This utility currently checks whether the first userOp of the loop is NOT
1721/// before the last defineOp of the consumer operand, because we need to move
1722/// the whole loop structure right before the `firstUserOfLoop`. This utility
1723/// thus helps ensure that no invalid IR is formed, i.e. no backward slice
1724/// of consumerOp is dominated by the `firstUserOfLoop`. That is:
1725///
1726/// ```
1727/// %0 = scf.for() {
1728/// ...
1729/// }
1730/// ...
1731/// %1 = firstUserOfLoop(%0)
1732/// ...
1733/// %2 = lastDefOfConsumerOperand
1734/// ...
1735/// %3 = consumerOp(%2)
1736/// ```
1737///
1738/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1739/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1740/// a.k.a. use-def chain violation:
1741///
1742/// ```
1743/// %0:2 = scf.for() {
1744/// // use before define error
1745/// %3 = tiledConsumerOp(%2)
1746/// }
1747/// %1 = firstUserOfLoop(%0)
1748/// ...
1749/// %2 = lastDefOfConsumerOperand
1750/// ```
1751///
1752/// @param loopOp: loop operation
1753/// @param consumerOp: consumer operation
1754/// @param reorderOperations: the flag controls whether to reorder the
1755/// backward slice w.r.t. the defineOp of `consumerOp` operands.
1756/// @return: the computed backward slice of consumerOp, excluding those
1757/// operations that already dominate `firstUserOfLoop`.
1758static FailureOr<llvm::SetVector<Operation *>>
1759checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
1760 bool reorderOperations) {
1761 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1762 if (failed(firstUserOfLoop))
1763 return failure();
1764
1765 BackwardSliceOptions options;
1766 DominanceInfo dominanceInfo;
1767 options.inclusive = true;
1768 options.omitBlockArguments = true;
1769 bool includeLoopOp = false;
1770 options.filter = [&](Operation *op) {
1771 if (op == loopOp) {
1772 includeLoopOp = true;
1773 return false;
1774 }
1775 // Cut off the slice to not include any operation that already dominates
1776 // firstUserOfLoop.
1777 return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1778 };
1779 llvm::SetVector<Operation *> slice;
1780 for (auto operand : consumerOp->getOperands()) {
1781 LogicalResult result = getBackwardSlice(operand, &slice, options);
1782 assert(result.succeeded() && "expected a backward slice");
1783 (void)result;
1784 }
1785
1786 if (!slice.empty()) {
1787 // consumerOp may have a producer that is also a user of loopOp.
1788 // E.g.
1789 // ```
1790 // %0 = %loopOp
1791 // %1 = consumerOp1 ins(%0)
1792 // %2 = consumerOp2 ins(%0, %1)
1793 // ```
1794 // We cannot fuse consumerOp2 into loopOp due to the use-def chain, unless
1795 // consumerOp1 has already been fused into loopOp before.
1796 if (includeLoopOp || !reorderOperations)
1797 return failure();
1798 }
1799
1800 return slice;
1801}
1802
1803/// Fetches the OpOperand of the first valid user (and use) of the value `val`
1804/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1805/// Returns failure otherwise.
1806static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1807 Operation *loopOp,
1808 unsigned resultNumber) {
1809 if (!isa<LoopLikeOpInterface>(loopOp))
1810 return failure();
1811 Value val = loopOp->getResult(resultNumber);
1812 Block *loopBlock = loopOp->getBlock();
1813 for (OpOperand &opOperand : val.getUses()) {
1814 Operation *consumerOp = opOperand.getOwner();
1815 // Step 1. Check if the user is tilable.
1816 if (!isa<TilingInterface>(consumerOp) ||
1817 !isa<DestinationStyleOpInterface>(consumerOp)) {
1818 // TODO: We have to init the result of the consumer before the scf.for; use
1819 // DestinationStyleOpInterface to get the result shape from the init for now.
1820 // Add support for other ops, e.g. ops that implement InferTypeOpInterface.
1821 continue;
1822 }
1823 // Step 2. Check if the user is in the same block.
1824 if (loopBlock != consumerOp->getBlock())
1825 continue;
1826 // Step 3. Check if the user has subsequent users. Otherwise, it has
1827 // usually already been tiled.
1828 if (consumerOp->use_empty())
1829 continue;
1830 // Step 4. Check the assumption for the loop with `reorderOperations` enabled.
1831 FailureOr<llvm::SetVector<Operation *>> slice =
1832 checkAssumptionForLoop(loopOp, consumerOp, true);
1833 if (failed(slice))
1834 continue;
1835 // Step 5. If the backward slice is not empty, move it before
1836 // firstUserOfLoop.
1837 if (!slice->empty()) {
1838 mlir::topologicalSort(*slice);
1839 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1840 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1841 for (auto op : *slice) {
1842 rewriter.moveOpBefore(op, *firstUserOfLoop);
1843 }
1844 }
1845 return &opOperand;
1846 }
1847 return failure();
1848}
1849
1850/// Check that the loop nest is perfectly nested.
1851/// The loops are expected to be ordered from outermost to innermost.
1852/// For example:
1853/// ```
1854/// %0 = scf.for()
1855/// %1 = scf.for()
1856/// %2 = scf.for()
1857/// %3 = ...
1858/// yield %3
1859/// yield %2
1860/// yield %1
1861/// ```
1862/// Here loops should be [%0, %1].
1863static bool
1864isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
1865 assert(!loops.empty() && "unexpected empty loop nest");
1866 if (loops.size() == 1) {
1867 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1868 }
1869 for (auto [outerLoop, innerLoop] :
1870 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1871 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1872 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1873 if (!outerFor || !innerFor) {
1874 return false;
1875 }
1876 auto outerBBArgs = outerFor.getRegionIterArgs();
1877 auto innerIterArgs = innerFor.getInitArgs();
1878 if (outerBBArgs.size() != innerIterArgs.size()) {
1879 return false;
1880 }
1881
1882 for (auto [outerBBArg, innerIterArg] :
1883 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1884 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1885 innerIterArg != outerBBArg) {
1886 return false;
1887 }
1888 }
1889
1890 ValueRange outerYields =
1891 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1892 ValueRange innerResults = innerFor.getResults();
1893 if (outerYields.size() != innerResults.size()) {
1894 return false;
1895 }
1896 for (auto [outerYield, innerResult] :
1897 llvm::zip_equal(outerYields, innerResults)) {
1898 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1899 outerYield != innerResult) {
1900 return false;
1901 }
1902 }
1903 }
1904 return true;
1905}
1906
1907/// Fetch the untiled consumer of the outermost scf.for's result which is
1908/// yielded by a tensor.insert_slice from the innermost scf.for. This function
1909/// makes the following assumptions:
1910/// 1. tensor.insert_slice has scf.yield as its only user.
1911/// 2. scf.for's corresponding result has only one use.
1912/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
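/// Illustrative sketch of the expected structure (schematic IR):
///
/// ```
/// %0 = scf.for .. iter_args(%arg0 = ..) {
///   %1 = scf.for .. iter_args(%arg1 = %arg0) {
///     %2 = tensor.insert_slice %tiled into %arg1[..]
///     scf.yield %2
///   }
///   scf.yield %1
/// }
/// %3 = consumer ins(%0 : ) ..
/// ```
///
/// Here the OpOperand to return would be the use of %0 in the consumer.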
1913static FailureOr<OpOperand *>
1914getUntiledConsumerFromSlice(RewriterBase &rewriter,
1915 tensor::InsertSliceOp candidateSliceOp,
1916 MutableArrayRef<LoopLikeOpInterface> loops) {
1917 assert(!loops.empty() && "unexpected loops to be empty");
1918 // 1. Expect slice to be part of the body of the inner most loop.
1919 Operation *containingOp = candidateSliceOp->getParentOp();
1920 if (containingOp != loops.back()) {
1921 return rewriter.notifyMatchFailure(
1922 candidateSliceOp,
1923 "expected slice to be within body of inner-most loop");
1924 }
1925
1926 // 2. Check that the loop is perfectly nested.
1927 if (!isPerfectlyNestedForLoops(loops)) {
1928 return rewriter.notifyMatchFailure(
1929 candidateSliceOp, "expected passed loops to be perfectly nested.");
1930 }
1931
1932 if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1933 return failure();
1934 Value sliceResult = candidateSliceOp.getResult();
1935
1936 // 3. Fetch the corresponding output.
1937 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1938 unsigned resultNumber = yieldOpOperand.getOperandNumber();
1939
1940 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1941
1942 return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1943}
1944
1945/// Fetch the first untiled consumer of a scf.forall's result which is yielded
1946/// by a tensor.parallel_insert_slice.
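/// Illustrative sketch (schematic IR):
///
/// ```
/// %0 = scf.forall .. shared_outs(%arg0 = ..) {
///   ..
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tiled into %arg0[..]
///   }
/// }
/// %1 = consumer ins(%0 : ) ..
/// ```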
1947static FailureOr<OpOperand *>
1948getUntiledConsumerFromSlice(RewriterBase &rewriter,
1949 tensor::ParallelInsertSliceOp candidateSliceOp,
1950 MutableArrayRef<LoopLikeOpInterface> loops) {
1951 assert(!loops.empty() && "unexpected loops to be empty");
1952 // 1. Check that the surrounding loop is a single scf.forall loop.
1953 if (loops.size() != 1) {
1954 return rewriter.notifyMatchFailure(
1955 candidateSliceOp, "expected single surrounding scf.forall");
1956 }
1957 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1958 if (!forallOp) {
1959 return rewriter.notifyMatchFailure(
1960 candidateSliceOp, "expected single surrounding scf.forall");
1961 }
1962
1963 // 2. Fetch the corresponding output
1964 Value sliceDest = candidateSliceOp.getDest();
1965 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1966 if (!iterArg)
1967 return failure();
1968 if (iterArg.getOwner()->getParentOp() != forallOp)
1969 return failure();
1970
1971 unsigned resultNumber =
1972 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1973 .getResultNumber();
1974
1975 return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
1976}
1977
1978/// A utility to fetch an untiled consumer of
1979/// tensor.insert_slice/tensor.parallel_insert_slice.
1980static FailureOr<OpOperand *>
1981getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
1982 MutableArrayRef<LoopLikeOpInterface> loops) {
1983 assert(!loops.empty() && "unexpected empty loops");
1984 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1985 return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
1986 } else if (auto parallelInsertSlice =
1987 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1988 return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
1989 } else {
1990 return failure();
1991 }
1992}
1993
1994/// Implementation of fusing consumer of a single slice by computing the
1995/// slice of the consumer in-place for scf loop.
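/// As an illustrative sketch only (schematic IR, not verified), fusing a
/// consumer into a tiled loop turns
///
/// ```
/// %0 = scf.for .. iter_args(%arg0 = %init0) {
///   %1 = tiledProducerOp ..
///   %2 = tensor.insert_slice %1 into %arg0[..]
///   scf.yield %2
/// }
/// %3 = consumerOp ins(%0 : ) outs(%init1 : )
/// ```
///
/// into roughly
///
/// ```
/// %0:2 = scf.for .. iter_args(%arg0 = %init0, %arg1 = %init1) {
///   %1 = tiledProducerOp ..
///   %2 = tensor.insert_slice %1 into %arg0[..]
///   %3 = tensor.extract_slice %arg1[..]
///   %4 = tiledConsumerOp ins(%1 : ) outs(%3 : )
///   %5 = tensor.insert_slice %4 into %arg1[..]
///   scf.yield %2, %5
/// }
/// ```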
1996FailureOr<scf::SCFFuseConsumerOfSliceResult>
1997mlir::scf::tileAndFuseConsumerOfSlice(
1998 RewriterBase &rewriter, Operation *candidateSliceOp,
1999 MutableArrayRef<LoopLikeOpInterface> loops) {
2000 // If `loops` is empty, return an error for now. The caller is expected
2001 // to handle this case.
2002 if (loops.empty()) {
2003 return candidateSliceOp->emitOpError(
2004 "cannot call tile and fuse consumer with an empty loop nest");
2005 }
2006 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2007 candidateSliceOp))
2008 return failure();
2009
2010 // 1. Get the consumer of scf.for for the result yielded by
2011 // tensor.insert_slice/parallel_insert_slice.
2012 FailureOr<OpOperand *> maybeConsumerOpOperand =
2013 getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
2014 if (failed(maybeConsumerOpOperand)) {
2015 return rewriter.notifyMatchFailure(candidateSliceOp,
2016 "could not fetch consumer to fuse");
2017 }
2018 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
2019 Operation *consumerOp = consumerOpOperand->getOwner();
2020 unsigned operandNumber = consumerOpOperand->getOperandNumber();
2021 unsigned resultNumber = 0;
2022 if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
2023 resultNumber = producerResult.getResultNumber();
2024 } else {
2025 return rewriter.notifyMatchFailure(
2026 consumerOp, "consumer op's operand doesn't seem to be an OpResult");
2027 }
2028
2029 LoopLikeOpInterface outerMostLoop = loops.front();
2030 LoopLikeOpInterface innerMostLoop = loops.back();
2031
2032 // Check assumption for loop with `reorderOperations` disabled.
2033 if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2034 return rewriter.notifyMatchFailure(
2035 outerMostLoop, "the first user of loop should not dominate any define "
2036 "of consumer operand(s)");
2037 }
2038
2039 OpBuilder::InsertionGuard g(rewriter);
2040
2041 // 2. Check consumer is not using scf loop's output as init.
2042 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2043 if (!dstOp)
2044 return rewriter.notifyMatchFailure(consumerOp,
2045 "consumer op is not a DPS operation");
2046 SmallVector<Value> dpsInits =
2047 llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2048 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2049 return rewriter.notifyMatchFailure(
2050 consumerOp,
2051 "consumer op taking the result of scf.for as init is not supported");
2052 }
2053 SmallVector<Value> newInits = dpsInits;
2054
2055 Location loc = outerMostLoop->getLoc();
2056
2057 // 3. Move the whole loop structure right before firstUserOfLoop, the
2058 // dominance should be already ensured by `checkAssumptionForLoop`.
2059 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2060 if (failed(firstUserOfLoop)) {
2061 return rewriter.notifyMatchFailure(
2062 outerMostLoop, "could not find the first user of outer most loop");
2063 }
2064 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2065
2066 // 4. Set insertion point before terminator op of the loop and create a new
2067 // tensor.insert_slice. In the scf.for case this is a clone of the
2068 // candidateSliceOp whereas in the scf.forall case this is created from the
2069 // operands of tensor.parallel_insert_slice.
2070 tensor::InsertSliceOp clonedInsertSliceOp;
2071 if (auto sliceOp =
2072 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2073 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2074 rewriter.setInsertionPoint(newForallOp.getTerminator());
2075 clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2076 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2077 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2078 } else {
2079 rewriter.setInsertionPoint(candidateSliceOp);
2080 clonedInsertSliceOp =
2081 cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2082 }
2083
2084 // 5.a. Clone consumer op.
2085 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2086
2087 // 5.b. Replace all uses of the loop result with the result of the cloned
2088 // tensor.insert_slice.
2089 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2090 rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2091 operandToReplace.set(clonedInsertSliceOp.getResult());
2092 });
2093
2094 // 6. Perform tiling of the cloned consumer and replace the operand at
2095 // `operandNumber` with the source of the cloned tensor.insert_slice op.
2096 auto ossSliceOp =
2097 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2098 FailureOr<TilingResult> tileAndFuseResult =
2099 tensor::replaceInsertSliceWithTiledConsumer(
2100 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2101 if (failed(tileAndFuseResult)) {
2102 return failure();
2103 }
2104 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2105 rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2106 clonedInsertSliceOp.getSource());
2107
2108 // 7. Reconstruct [nested] loop with new inits.
2109 YieldTiledValuesFn newYieldValuesFn =
2110 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2111 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2112 SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
2113 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2114 OpBuilder::InsertionGuard g(innerRewriter);
2115 // 8. Set inner insertPoint right before tiled consumer op.
2116 innerRewriter.setInsertionPoint(tiledConsumerOp);
2117
2118 SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
2119 SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
2120 SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
2121
2122 // 9. Check that all insert strides are 1.
2123 if (!llvm::all_of(strides, isOneInteger)) {
2124 return rewriter.notifyMatchFailure(
2125 candidateSliceOp, "containingOp's result yield with stride");
2126 }
2127
2128 // 10. Try to get iter domain position from input position. Use
2129 // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2130 // domain may require index computation based on the result size. The
2131 // sizes and offsets should be the same either way, but using
2132 // tiledConsumerOp could lead to some chained unnecessary extra index
2133 // computation.
2134 SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2135 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2136 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2137 iterDomainSizes))) {
2138 return rewriter.notifyMatchFailure(
2139 clonedConsumerOp,
2140 "can't get iter domain position from input position");
2141 }
2142
2143 // 11. Try to fetch the offset and size for all results of the cloned
2144 // consumer. This would then be used to form the corresponding
2145 // tensor.insert_slice/parallel_insert_slice later.
2146 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2147 SmallVector<SmallVector<OpFoldResult>> resultOffsets(
2148 totalNumResultsOfConsumer);
2149 SmallVector<SmallVector<OpFoldResult>> resultSizes(
2150 totalNumResultsOfConsumer);
2151 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2152 if (failed(tiledConsumerOp.getResultTilePosition(
2153 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2154 resultOffsets[idx], resultSizes[idx]))) {
2155 return rewriter.notifyMatchFailure(
2156 tiledConsumerOp,
2157 "can't get result domain position from iter domain position");
2158 }
2159 }
2160
2161 // 12. Create `extract_slice` for `iter_args` for DPS operation if
2162 // necessary.
2163 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2164 tiledConsumerOp.getOperation())) {
2165 rewriter.setInsertionPoint(tiledDestStyleOp);
2166 for (const auto &&[index, newRegionArg] :
2167 llvm::enumerate(newRegionIterArgs)) {
2168 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2169 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2170 SmallVector<OpFoldResult>(resultOffsets[index].size(),
2171 rewriter.getIndexAttr(1)));
2172 // Make a copy of index to avoid a capturing structured binding, which
2173 // is a C++20 extension.
2174 auto dstNumber = index;
2175 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2176 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2177 });
2178 }
2179 }
2180
2181 // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2182 // caller.
2183 Block *block = rewriter.getInsertionPoint()->getBlock();
2184 rewriter.setInsertionPoint(block->getTerminator());
2185 for (const auto &&[index, result] :
2186 llvm::enumerate(tiledConsumerOp->getResults())) {
2187 tiledResult.push_back(result);
2188 tiledOffset.emplace_back(resultOffsets[index]);
2189 tiledSizes.emplace_back(resultSizes[index]);
2190 }
2191 return success();
2192 };
2193 // 14. Add new inits to [nested] loops.
2194 if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
2195 newYieldValuesFn))) {
2196 return rewriter.notifyMatchFailure(tiledConsumerOp,
2197 "unable to add new inits to nest loop");
2198 }
2199
2200 // 15. Replace the result of scf loop and consumer op with new loop's
2201 // results.
2202
2203 for (auto &&[oldResult, newResult] :
2204 llvm::zip(consumerOp->getResults(),
2205 loops.front()->getResults().take_back(newInits.size()))) {
2206 rewriter.replaceAllUsesWith(oldResult, newResult);
2207 }
2208
2209 // 16. Need to erase the old scf loop and the cloned consumer op.
2210 rewriter.eraseOp(clonedConsumerOp);
2211
2212 return scf::SCFFuseConsumerOfSliceResult{
2213 consumerOpOperand,
2214 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2215 tileAndFuseResult->tiledOps};
2216}
2217
2218//===----------------------------------------------------------------------===//
2219// lowerToLoopsUsingSCFForOp implementation.
2220//===----------------------------------------------------------------------===//
2221
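/// Illustrative sketch only (schematic IR, not verified): each dimension of
/// the iteration domain becomes one scf.for with bounds taken from the
/// domain's offset/size/stride, and the op's scalar implementation is emitted
/// in the innermost body:
///
/// ```
/// scf.for %i = %offset0 to %size0 step %stride0 {
///   scf.for %j = %offset1 to %size1 step %stride1 {
///     // scalar implementation of the op at (%i, %j)
///   }
/// }
/// ```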
2222FailureOr<SmallVector<scf::ForOp>>
2223mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2224 TilingInterface op) {
2225 // TODO: Handle cases where the op has results if needed.
2226 if (op->getNumResults() > 0) {
2227 return rewriter.notifyMatchFailure(
2228 op, "unable to lower to loops operations with return values");
2229 }
2230
2231 SmallVector<Range> domain = op.getIterationDomain(rewriter);
2232 SmallVector<Value> ivs;
2233 SmallVector<scf::ForOp> loops;
2234 Location loc = op.getLoc();
2235 for (auto loopRange : domain) {
2236 Value offsetVal =
2237 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2238 Value sizeVal =
2239 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2240 Value strideVal =
2241 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2242 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2243 strideVal, ValueRange{});
2244 loops.push_back(loop);
2245 ivs.push_back(loop.getInductionVar());
2246 rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2247 }
2248 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2249 return failure();
2250 }
2251 return loops;
2252}
2253
