1//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the tiling using TilingInterface.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/SCF/Utils/Utils.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Interfaces/DestinationStyleOpInterface.h"
25#include "mlir/Interfaces/TilingInterface.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/Debug.h"
28#include <optional>
29
30#define DEBUG_TYPE "tile-using-interface"
31
32using namespace mlir;
33
34scf::SCFTilingOptions &
35scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
36 assert(!tileSizeComputationFunction && "tile sizes already set");
37 auto tileSizes = llvm::to_vector(Range&: ts);
38 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
39 return tileSizes;
40 };
41 return *this;
42}
43
44/// Helper method to adjust the interchange vector to match the iteration
45/// domain.
46static SmallVector<int64_t>
47fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
48 size_t iterationDomainSize) {
49 SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector);
50 if (filledVector.size() < iterationDomainSize) {
51 auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize);
52 filledVector.append(in_start: range.begin(), in_end: range.end());
53 }
54 if (filledVector.size() > iterationDomainSize)
55 filledVector.resize(N: iterationDomainSize);
56 return filledVector;
57}
58
59//===----------------------------------------------------------------------===//
60// tileUsingSCF implementation.
61//===----------------------------------------------------------------------===//
62
63// Check if `stride` evenly divides the trip count `size - offset`.
64static bool tileDividesIterationDomain(Range loopRange) {
65 std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset);
66 if (!offsetAsInt)
67 return false;
68 std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size);
69 if (!sizeAsInt)
70 return false;
71 std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride);
72 if (!strideAsInt)
73 return false;
74 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
75}
76
77/// Returns the bounded tile size given the current `iv`, `loopRange` and
78/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
79static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
80 Range loopRange, Value iv,
81 OpFoldResult tileSize) {
82 std::optional<int64_t> ts = getConstantIntValue(ofr: tileSize);
83 if (ts && ts.value() == 1)
84 return tileSize;
85
86 if (tileDividesIterationDomain(
87 loopRange: Range{.offset: loopRange.offset, .size: loopRange.size, .stride: tileSize}))
88 return tileSize;
89
90 // The tile size to use (to avoid out of bounds access) is minimum of
91 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
92 // loop.
93 AffineExpr s0, s1, d0;
94 bindDims(ctx: b.getContext(), exprs&: d0);
95 bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1);
96 AffineMap minMap = AffineMap::get(dimCount: 1, symbolCount: 2, results: {s0, s1 - d0}, context: b.getContext());
97 Value size = getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange.size);
98 return affine::makeComposedFoldedAffineMin(
99 b, loc, map: minMap, operands: SmallVector<OpFoldResult>{iv, tileSize, size});
100}
101
/// Callback that generates the body of the innermost tiled loop and reports
/// the values to yield from it. Used both when generating loop nests and by
/// `yieldTiledValuesAndReplaceLoop`.
/// - `ivs`: induction variables of the generated loops (an empty range is
///   passed when no loop was generated).
/// - `newBbArgs`: the region iter_args (or the shared outputs of an
///   `scf.forall`) into which the tiled values are inserted.
/// - `tiledValues`: out-param; the tiled values to return. Must be the same
///   size as `newBbArgs`; each element is inserted into the corresponding
///   element of `newBbArgs`.
/// - `resultOffsets`: out-param, same size as `tiledValues`; the offsets to use
///   when inserting each element of `tiledValues` into the corresponding
///   element of `newBbArgs`.
/// - `resultSizes`: out-param, same size as `tiledValues`; the sizes of each
///   element of `tiledValues` as inserted into the corresponding element of
///   `newBbArgs`.
/// If the callback returns `failure()`, it is expected to clean up (erase) any
/// operations it inserted.
using YieldTiledValuesFn = std::function<LogicalResult(
    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
    SmallVector<Value> &tiledValues,
    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
122
123/// Clones the operation and updates the destination if the operation
124/// implements the `DestinationStyleOpInterface`.
125static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
126 Operation *op,
127 ValueRange newDestArgs) {
128 Operation *clonedOp = rewriter.clone(op&: *op);
129 if (newDestArgs.empty())
130 return clonedOp;
131 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
132 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
133 return clonedOp;
134}
135
136/// Generate the tile-loop nest using `scf.for` operation.
137/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
138/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
139/// - `destinationTensors` are the init values to use for the outer most loop.
140/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
141/// most
142/// loop.
143/// - `loops` is an in-out parameter into which the generated loops are
144/// populated.
145static LogicalResult generateLoopNestUsingForOp(
146 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
147 ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
148 YieldTiledValuesFn yieldTiledValuesFn,
149 SmallVector<LoopLikeOpInterface> &loops) {
150 assert(!loopRanges.empty() && "unexpected empty loop ranges");
151 assert(loopRanges.size() == tileSizes.size() &&
152 "expected as many tile sizes as loop ranges");
153 OpBuilder::InsertionGuard guard(rewriter);
154 SmallVector<Value> ivs;
155
156 for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) {
157 // No loops if tile size is zero. Set offset and size to the loop
158 // offset and size.
159 if (isConstantIntValue(ofr: tileSize, value: 0))
160 continue;
161
162 Value lb = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.offset);
163 Value ub = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.size);
164 Value step = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: tileSize);
165 auto loop =
166 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
167 [](OpBuilder &bodyBuilder, Location bodyLoc,
168 Value iv, ValueRange /*iterArgs*/) {});
169 loops.push_back(loop);
170 ivs.push_back(Elt: loop.getInductionVar());
171 rewriter.setInsertionPointToEnd(loop.getBody());
172 destinationTensors = loop.getRegionIterArgs();
173 }
174
175 SmallVector<Value> tiledResults;
176 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
177 if (failed(result: yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
178 tiledResults, resultOffsets, resultSizes))) {
179 return rewriter.notifyMatchFailure(
180 arg&: loc, msg: "failed to generate inner tile loop body");
181 }
182 if (loops.empty())
183 return success();
184
185 // 6. Yield all the results of the tiled operation.
186 SmallVector<Value> yieldedValues;
187 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
188 llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets,
189 args&: resultSizes)) {
190 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
191 rewriter.getIndexAttr(1));
192 auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
193 loc, tiledValue, destinationTensor, resultOffset, resultSize,
194 resultStride);
195 yieldedValues.push_back(Elt: insertSlice);
196 }
197 rewriter.create<scf::YieldOp>(loc, yieldedValues);
198
199 // Add the scf.yield operations for all the outer loops.
200 for (auto [outerLoop, innerLoop] :
201 llvm::zip_equal(MutableArrayRef(loops).drop_back(),
202 MutableArrayRef(loops).drop_front())) {
203 rewriter.setInsertionPointToEnd(
204 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
205 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
206 }
207 return success();
208}
209
210/// Generate the tile-loop nest using `scf.forall` operation.
211/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
212/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
213/// - `destinationTensors` are the init values to use for the outer most loop.
214/// - `mappingVector` is the mapping attributes to use for loop construction.
215/// Can be empty.
216/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
217/// most
218/// loop.
219/// - `loops` is an in-out parameter into which the generated loops are
220/// populated.
221static LogicalResult generateLoopNestUsingForallOp(
222 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
223 ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
224 ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
225 SmallVector<LoopLikeOpInterface> &loops) {
226 SmallVector<OpFoldResult> lbs, ubs, steps;
227 assert(!loopRanges.empty() && "unexpected empty loop ranges");
228 assert(loopRanges.size() == tileSizes.size() &&
229 "expected as many tile sizes as loop ranges");
230 OpBuilder::InsertionGuard guard(rewriter);
231 SmallVector<OpFoldResult> offsets(loopRanges.size()),
232 sizes(loopRanges.size());
233
234 for (auto [tileSize, loopRange] : llvm::zip_equal(t&: tileSizes, u&: loopRanges)) {
235 if (isConstantIntValue(ofr: tileSize, value: 0))
236 continue;
237 lbs.push_back(Elt: loopRange.offset);
238 ubs.push_back(Elt: loopRange.size);
239 steps.push_back(Elt: tileSize);
240 }
241 assert(!lbs.empty() && "Expected at least one loop range");
242
243 std::optional<ArrayAttr> mappingAttr;
244 if (!mappingVector.empty())
245 mappingAttr = rewriter.getArrayAttr(mappingVector);
246
247 auto forallOp = rewriter.create<scf::ForallOp>(
248 loc, lbs, ubs, steps, destinationTensors, mappingAttr);
249 loops.push_back(forallOp);
250
251 rewriter.setInsertionPoint(forallOp.getTerminator());
252 destinationTensors = forallOp.getRegionOutArgs();
253
254 SmallVector<Value> tiledResults;
255 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
256 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
257 destinationTensors, tiledResults, resultOffsets,
258 resultSizes)))
259 return rewriter.notifyMatchFailure(arg&: loc, msg: "failed to generate loop body");
260
261 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
262 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
263 llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets,
264 args&: resultSizes)) {
265 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
266 rewriter.getIndexAttr(1));
267
268 rewriter.create<tensor::ParallelInsertSliceOp>(
269 loc, tiledValue, destinationTensor, resultOffset, resultSize,
270 resultStride);
271 }
272 return success();
273}
274
275/// Generate the tile-loop nest using the loop construct specifed in `options`.
276/// - `options`: Tiling options specified.
277/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
278/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
279/// - `destinationTensors` are the init values to use for the outer most loop.
280/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
281/// most
282/// loop.
283/// - `loops` is an in-out parameter into which the generated loops are
284/// populated.
285static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
286 const scf::SCFTilingOptions &options,
287 ArrayRef<Range> loopRanges,
288 ArrayRef<OpFoldResult> tileSizes,
289 ValueRange destinationTensors,
290 YieldTiledValuesFn tiledBodyFn,
291 SmallVector<LoopLikeOpInterface> &loops) {
292 // If the tile sizes are all zero, no loops are generated. Just call the
293 // callback function to handle untiled case.
294 if (llvm::all_of(Range&: tileSizes, P: isZeroIndex)) {
295 SmallVector<Value> tiledResults;
296 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
297 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
298 tiledResults, resultOffsets, resultSizes);
299 }
300 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
301 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
302 destinationTensors, tiledBodyFn, loops);
303 }
304 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
305 return generateLoopNestUsingForallOp(
306 rewriter, loc, loopRanges, tileSizes, options.mappingVector,
307 destinationTensors, tiledBodyFn, loops);
308 }
309 return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type");
310}
311
312/// Append the specified additional `newInitOperands` operands to the
313/// loops existing `init` operands (or similar), and replace `loopOp` with
314/// the new loop that has the additional init operands. The loop body of
315/// this loop is moved over to the new loop. `yieldTiledValuesFn`
316/// is called to get the new tiled values returned, and the offset
317/// and sizes at which the tiled value is inserted into the
318/// new region iter_args that correspond to the newly added init operands.
319template <typename LoopType>
320FailureOr<LoopLikeOpInterface>
321yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
322 ValueRange newInitOperands,
323 YieldTiledValuesFn yieldTiledValuesFn) {
324 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
325}
326
327/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
328template <>
329FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
330 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
331 YieldTiledValuesFn yieldTiledValuesFn) {
332 OpBuilder::InsertionGuard g(rewriter);
333 Location loc = loopOp.getLoc();
334 rewriter.setInsertionPoint(loopOp);
335
336 auto inits = llvm::to_vector(loopOp.getInitArgs());
337 inits.append(newInitOperands.begin(), newInitOperands.end());
338 auto newLoop = rewriter.create<scf::ForOp>(
339 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
340 inits, [](OpBuilder &, Location, Value, ValueRange) {});
341
342 // Move the loop body to the new op.
343 Block *loopBody = loopOp.getBody();
344 Block *newLoopBody = newLoop.getBody();
345 rewriter.mergeBlocks(
346 loopBody, newLoopBody,
347 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
348
349 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
350 rewriter.setInsertionPoint(yieldOp);
351
352 SmallVector<Value> tiledValues;
353 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
354 ValueRange newRegionIterArgs =
355 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
356 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
357 newRegionIterArgs, tiledValues, resultOffsets,
358 resultSizes))) {
359 rewriter.eraseOp(newLoop);
360 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
361 }
362
363 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
364 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
365 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
366 resultSizes)) {
367 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
368 rewriter.getIndexAttr(1));
369 Value insert = rewriter.create<tensor::InsertSliceOp>(
370 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
371 resultStride);
372 newYieldValues.push_back(insert);
373 }
374
375 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
376 rewriter.replaceOp(loopOp,
377 newLoop->getResults().take_front(loopOp.getNumResults()));
378 return cast<LoopLikeOpInterface>(newLoop.getOperation());
379}
380
381/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
382template <>
383FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
384 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
385 YieldTiledValuesFn yieldTiledValuesFn) {
386 OpBuilder::InsertionGuard g(rewriter);
387 Location loc = loopOp.getLoc();
388 rewriter.setInsertionPoint(loopOp);
389 auto inits = llvm::to_vector(loopOp.getOutputs());
390 inits.append(newInitOperands.begin(), newInitOperands.end());
391 auto newLoop = rewriter.create<scf::ForallOp>(
392 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
393 loopOp.getMixedStep(), inits, loopOp.getMapping(),
394 [](OpBuilder &, Location, ValueRange) {});
395
396 // Move the region of the current block to the newly created op.
397 Block *loopBody = loopOp.getBody();
398 Block *newLoopBody = newLoop.getBody();
399 rewriter.mergeBlocks(
400 loopBody, newLoopBody,
401 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
402
403 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
404 rewriter.setInsertionPoint(terminator);
405 SmallVector<Value> tiledValues;
406 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
407 ValueRange regionIterArgs =
408 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
409 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
410 regionIterArgs, tiledValues, resultOffsets,
411 resultSizes))) {
412 rewriter.eraseOp(newLoop);
413 return rewriter.notifyMatchFailure(loopOp,
414 "failed to get yielded tiled values");
415 }
416
417 // Update the terminator.
418 rewriter.setInsertionPointToEnd(terminator.getBody());
419
420 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
421 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
422 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
423 rewriter.getIndexAttr(1));
424 rewriter.create<tensor::ParallelInsertSliceOp>(
425 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
426 resultStride);
427 }
428
429 rewriter.replaceOp(loopOp,
430 newLoop->getResults().take_front(loopOp.getNumResults()));
431 return cast<LoopLikeOpInterface>(newLoop.getOperation());
432}
433
434/// Implementation of `yieldTiledValuesAndReplaceLoop` for
435/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
436/// supported loop type.
437FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
438 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
439 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
440 return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
441 loopLikeOp.getOperation())
442 .Case<scf::ForOp, scf::ForallOp>(
443 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
444 return yieldTiledValuesAndReplaceLoop(
445 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
446 })
447 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
448 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
449 });
450}
451
452/// Method to add new init values to a loop nest. Updates `loops` in-place with
453/// new loops that use the `newInitValues`.
454/// The outer-loops are updated to yield the new result values of the inner
455/// loop. For the innermost loop, the call back `getNewYields` is invoked to get
456/// the additional values to yield form the innermost loop.
457static LogicalResult addInitOperandsToLoopNest(
458 RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
459 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
460 SmallVector<scf::ForOp> newLoops;
461 if (loops.empty())
462 return success();
463 OpBuilder::InsertionGuard g(rewriter);
464 rewriter.setInsertionPoint(loops.front());
465
466 SmallVector<Value> ivs;
467 for (auto &loop : loops.drop_back()) {
468 rewriter.setInsertionPoint(loop);
469
470 // if loops.size() > 1 we assume that scf.for is used for the loops.
471 auto forLoop = cast<scf::ForOp>(loop.getOperation());
472
473 // Create a new loop with the new init values for this loop.
474 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
475 newInits.append(newInitValues.begin(), newInitValues.end());
476 auto newLoop = rewriter.create<scf::ForOp>(
477 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
478 forLoop.getStep(), newInits,
479 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
480
481 // Merge the body of the new loop with the body of the old loops.
482 SmallVector<Value> sourceBlockArgs;
483 sourceBlockArgs.push_back(newLoop.getInductionVar());
484 auto newRegionIterArgs = newLoop.getRegionIterArgs();
485 sourceBlockArgs.append(
486 newRegionIterArgs.begin(),
487 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
488 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
489 rewriter.replaceOp(
490 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
491 loop = newLoop;
492 ivs.push_back(newLoop.getInductionVar());
493 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
494 }
495
496 // Update the loop body of the innermost loop to get new yield values.
497 LoopLikeOpInterface innerMostLoop = loops.back();
498 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
499 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
500 getNewTiledYieldsFn);
501
502 if (failed(newInnerMostLoop))
503 return innerMostLoop.emitOpError("failed to return additional yields");
504 loops.back() = newInnerMostLoop.value();
505
506 // Make all other loops except the innermost loops yield the values returned
507 // by the inner loop.
508 for (auto [outerLoop, innerLoop] :
509 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
510 // Again assume that all the outer loops are scf.for operations.
511 auto outerForLoop = cast<scf::ForOp>(outerLoop);
512 auto outerLoopYield =
513 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
514 SmallVector<Value> newYields =
515 llvm::to_vector(outerLoopYield.getOperands());
516 ValueRange additionalYields =
517 innerLoop->getResults().take_back(newInitValues.size());
518 newYields.append(additionalYields.begin(), additionalYields.end());
519 rewriter.setInsertionPoint(outerLoopYield);
520 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
521 }
522 return success();
523}
524
525/// Implementation of tiling transformation of `op` that implements the
526/// `TilingInterface` using `scf.for` to iterate over the tiles.
527FailureOr<scf::SCFTilingResult>
528mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
529 const scf::SCFTilingOptions &options) {
530 OpBuilder::InsertionGuard guard(rewriter);
531 rewriter.setInsertionPointAfter(op);
532
533 if (!options.tileSizeComputationFunction) {
534 return rewriter.notifyMatchFailure(
535 op, "missing tile size computation function");
536 }
537
538 // 1. Get the range of the loops that are represented by the operation.
539 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
540 size_t numLoops = iterationDomain.size();
541
542 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
543 // skips tiling a particular dimension. This convention is significantly
544 // simpler to handle instead of adjusting affine maps to account for missing
545 // dimensions.
546 SmallVector<OpFoldResult> tileSizes =
547 options.tileSizeComputationFunction(rewriter, op);
548 if (tileSizes.size() < iterationDomain.size()) {
549 auto zero = rewriter.getIndexAttr(0);
550 tileSizes.append(numLoops - tileSizes.size(), zero);
551 }
552
553 // 3. If there is an interchange specified, permute the iteration domain and
554 // the tile sizes.
555 SmallVector<int64_t> interchangeVector;
556 if (!options.interchangeVector.empty()) {
557 interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector,
558 iterationDomainSize: iterationDomain.size());
559 }
560 if (!interchangeVector.empty()) {
561 if (!isPermutationVector(interchange: interchangeVector)) {
562 return rewriter.notifyMatchFailure(
563 op, "invalid intechange vector, not a permutation of the entire "
564 "iteration space");
565 }
566
567 applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector);
568 applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector);
569 }
570
571 FailureOr<TilingResult> tilingResult;
572 // 4. Define the lambda function used later to generate the body of the
573 // innermost tiled loop.
574 YieldTiledValuesFn innerYieldTiledValuesFn =
575 [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
576 ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
577 SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
578 SmallVector<SmallVector<OpFoldResult>> &resultSizes)
579 -> LogicalResult {
580 // 4a. Compute the `offsets` and `sizes` to use for tiling.
581 SmallVector<OpFoldResult> offsets, sizes;
582 {
583 int materializedLoopNum = 0;
584 for (auto [tileSize, loopRange] :
585 llvm::zip_equal(tileSizes, iterationDomain)) {
586 if (isConstantIntValue(tileSize, 0)) {
587 offsets.push_back(loopRange.offset);
588 sizes.push_back(loopRange.size);
589 continue;
590 }
591 Value iv = ivs[materializedLoopNum++];
592 offsets.push_back(iv);
593 sizes.push_back(
594 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
595 }
596 }
597
598 // 4b. If interchange was provided, apply inverse of the interchange
599 // to get back the offsets/sizes in the order to be specified.
600 if (!interchangeVector.empty()) {
601 auto inversePermutation = invertPermutationVector(permutation: interchangeVector);
602 applyPermutationToVector(inVec&: offsets, permutation: inversePermutation);
603 applyPermutationToVector(inVec&: sizes, permutation: inversePermutation);
604 }
605
606 // 5. Generate the tiled implementation within the inner most loop.
607
608 // 5a. Clone the operation within the loop body.
609 auto clonedOp = cast<TilingInterface>(
610 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
611
612 // 5b. Early return cloned op if tiling is not happening. We can not return
613 // the original op because it could lead to
614 // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
615 if (llvm::all_of(Range&: tileSizes, P: isZeroIndex)) {
616 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
617 tilingResult =
618 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
619 return success();
620 }
621
622 // 5c. Tile the cloned operation.
623 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
624 if (failed(result: tilingResult)) {
625 rewriter.eraseOp(op: clonedOp);
626 return op.emitOpError("faild to tile operation");
627 }
628
629 // 5d. Delete the cloned operation.
630 rewriter.eraseOp(op: clonedOp);
631
632 // 5e. Compute the offsets at which the result values are to be inserted
633 // back into its destinations.
634 for (auto [index, tiledValue] :
635 llvm::enumerate(First&: tilingResult->tiledValues)) {
636 tiledResults.push_back(Elt: tiledValue);
637 SmallVector<OpFoldResult> resultOffset, resultSize;
638 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
639 resultOffset, resultSize))) {
640 for (auto op : tilingResult->tiledOps) {
641 rewriter.eraseOp(op);
642 }
643 return rewriter.notifyMatchFailure(
644 op, "failed to get slice of result produced");
645 }
646 resultOffsets.emplace_back(Args: std::move(resultOffset));
647 resultSizes.emplace_back(Args: std::move(resultSize));
648 }
649
650 return success();
651 };
652
653 // 6. Find the destination tensors to use for the operation.
654 SmallVector<Value> destinationTensors;
655 if (failed(tensor::getOrCreateDestinations(b&: rewriter, loc: op.getLoc(), op: op,
656 result&: destinationTensors))) {
657 return rewriter.notifyMatchFailure(op,
658 "unable to create destination tensors");
659 }
660
661 // 7. Generate the tiled loops nest using the callback defined above.
662 SmallVector<LoopLikeOpInterface> loops;
663 if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
664 tileSizes, destinationTensors,
665 innerYieldTiledValuesFn, loops)))
666 return op.emitOpError("failed to generate tiling loops");
667 assert(succeeded(tilingResult) &&
668 "expected tiling result to be computed after loop generation");
669
670 // If loops are empty, the tiled op is used as the replacement for the untiled
671 // op.
672 if (loops.empty()) {
673 return scf::SCFTilingResult{tilingResult->tiledOps, loops,
674 tilingResult->tiledValues};
675 }
676
677 SmallVector<Value> replacements = llvm::map_to_vector(
678 loops.front()->getResults(), [](OpResult r) -> Value { return r; });
679 return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
680}
681
682FailureOr<scf::SCFReductionTilingResult>
683mlir::scf::tileReductionUsingScf(RewriterBase &b,
684 PartialReductionOpInterface op,
685 ArrayRef<OpFoldResult> tileSizes) {
686 Location loc = op.getLoc();
687 // Ops implementing PartialReductionOpInterface are expected to implement
688 // TilingInterface.
689 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
690 SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
691 auto tileSizesVector = llvm::to_vector(Range&: tileSizes);
692 if (tileSizesVector.size() < iterationDomain.size()) {
693 auto zero = b.getIndexAttr(0);
694 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
695 zero);
696 }
697 if (op->getNumResults() != 1)
698 return b.notifyMatchFailure(
699 op, "don't support ops with multiple results for now");
700 SmallVector<utils::IteratorType> iterators =
701 tilingInterfaceOp.getLoopIteratorTypes();
702
703 SmallVector<int> reductionDims;
704 for (auto [idx, iteratorType] :
705 llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
706 if (iteratorType == utils::IteratorType::reduction)
707 reductionDims.push_back(idx);
708 }
709
710 // 2. create the inital tensor value.
711 FailureOr<Operation *> identityTensor =
712 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
713 reductionDims);
714 if (failed(result: identityTensor))
715 return b.notifyMatchFailure(op,
716 "cannot create a tensor of identity value.");
717
718 // 3. Define the callback to use for generating the inner most tile loop body.
719 Operation *parallelOp = nullptr;
720 auto innerYieldTiledValuesFn =
721 [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
722 ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
723 SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
724 SmallVector<SmallVector<OpFoldResult>> &resultSizes)
725 -> LogicalResult {
726 SmallVector<OpFoldResult> offsets, sizes;
727 {
728 int materializedLoopNum = 0;
729 for (auto [tileSize, loopRange] :
730 llvm::zip_equal(tileSizesVector, iterationDomain)) {
731 if (isConstantIntValue(tileSize, 0)) {
732 offsets.push_back(loopRange.offset);
733 sizes.push_back(loopRange.size);
734 continue;
735 }
736 Value iv = ivs[materializedLoopNum++];
737 offsets.push_back(iv);
738 sizes.push_back(
739 getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
740 }
741 }
742
743 // 4a. Clone the operation.
744 auto clonedOp = cast<PartialReductionOpInterface>(
745 cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
746
747 // 4b. Tile the cloned operation.
748 parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
749 offsets, sizes, reductionDims);
750 // 4c. Delete the cloned operation.
751 b.eraseOp(op: clonedOp);
752
753 tiledResult.append(in_start: parallelOp->result_begin(), in_end: parallelOp->result_end());
754 // 4d. Compute the offsets and sizes needed to insert the result of the
755 // tiled value back into destination before yielding the destination.
756 SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
757 resultOffsets.emplace_back(Args: std::move(outOffsets));
758
759 SmallVector<OpFoldResult> outSizes;
760 for (size_t i = 0; i < offsets.size(); i++) {
761 outSizes.push_back(
762 Elt: tensor::getMixedSize(builder&: b, loc, value: parallelOp->getResult(idx: 0), dim: i));
763 }
764 resultSizes.emplace_back(Args: std::move(outSizes));
765 return success();
766 };
767
768 // 5. Generate the tiled implementation using the destination tensors.
769 SmallVector<Value> destinationTensors =
770 llvm::map_to_vector(C: identityTensor.value()->getResults(),
771 F: [](OpResult res) -> Value { return res; });
772
773 SmallVector<LoopLikeOpInterface> loops;
774 scf::SCFTilingOptions options;
775 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
776 if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
777 destinationTensors, innerYieldTiledValuesFn,
778 loops)))
779 return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
780
781 SmallVector<Value> replacements = llvm::map_to_vector(
782 loops.front()->getResults(), [](OpResult r) -> Value { return r; });
783
784 // 5. Apply the merge reduction to combine all the partial values.
785 b.setInsertionPointAfter(*loops.begin());
786 Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
787 b.replaceOp(op, mergeOp->getResults());
788
789 SCFReductionTilingResult results;
790 results.initialOp = *identityTensor;
791 results.loops = loops;
792 results.parallelTiledOp = parallelOp;
793 results.mergeOp = mergeOp;
794 return results;
795}
796
797//===----------------------------------------------------------------------===//
798// tileConsumerAndFuseProducersUsingSCF implementation.
799//===----------------------------------------------------------------------===//
800
801/// Return the untiled producer whose slice is used in a tiled consumer. The
802/// method traverses the tile loop nest (`loops`) if needed, and returns the
803/// `iter_args` of the outer most that is encountered. Traversing the iter_args
804/// indicates that this is a destination operand of the consumer. If there was
805/// no loop traversal needed, the second value of the returned tuple is empty.
806static std::tuple<OpResult, std::optional<OpOperand *>>
807getUntiledProducerFromSliceSource(OpOperand *source,
808 ArrayRef<LoopLikeOpInterface> loops) {
809 std::optional<OpOperand *> destinationIterArg;
810 auto loopIt = loops.rbegin();
811 while (auto iterArg = dyn_cast<BlockArgument>(Val: source->get())) {
812 auto loop = *loopIt;
813 if (iterArg.getOwner()->getParentOp() != loop)
814 break;
815 source = loop.getTiedLoopInit(iterArg);
816 loopIt++;
817 }
818 if (loopIt == loops.rend())
819 destinationIterArg = source;
820 return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg};
821}
822
823/// Implementation of fusing producer of a single slice by computing the
824/// slice of the producer in-place.
825std::optional<scf::SCFFuseProducerOfSliceResult>
826mlir::scf::tileAndFuseProducerOfSlice(
827 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
828 MutableArrayRef<LoopLikeOpInterface> loops) {
829 // 1. Get the producer of the source (potentially walking through
830 // `iter_args` of nested `scf.for`)
831 auto [fusableProducer, destinationInitArg] =
832 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
833 loops);
834 if (!fusableProducer)
835 return std::nullopt;
836 unsigned resultNumber = fusableProducer.getResultNumber();
837
838 OpBuilder::InsertionGuard g(rewriter);
839 rewriter.setInsertionPoint(candidateSliceOp);
840
841 // 2. Clone the fused producer
842 // 2a. Compute the destination operands to use for the cloned operation.
843 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
844 Operation *fusableProducerOp = fusableProducer.getOwner();
845 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
846 failed(tensor::getOrCreateDestinations(
847 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
848 origDestinationTensors)))
849 return std::nullopt;
850
851 clonedOpDestinationTensors = origDestinationTensors;
852 if (destinationInitArg &&
853 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
854 // 2b. If the producer is also destination style, then to maintain the
855 // destination passing style, update the destination of the producer to be
856 // the source of the slice.
857 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
858 }
859 // 2c. Clone the fused producer.
860 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
861 rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors);
862 // 2d. Update the source of the candidateSlice to be the cloned producer.
863 // Easier to just clone the slice with different source since replacements
864 // and DCE of cloned ops becomes easier
865 SmallVector<Value> candidateSliceOpOperands =
866 llvm::to_vector(candidateSliceOp->getOperands());
867 candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber);
868 tensor::ExtractSliceOp clonedCandidateSliceOp =
869 mlir::clone(rewriter, candidateSliceOp,
870 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
871
872 // 3. Generate the tiled implementation of the producer of the source
873 FailureOr<TilingResult> tileAndFuseResult =
874 tensor::replaceExtractSliceWithTiledProducer(
875 rewriter, clonedCandidateSliceOp,
876 clonedProducerOp->getResult(idx: resultNumber));
877 if (failed(result: tileAndFuseResult))
878 return std::nullopt;
879 // Note: Do not delete the candidateSliceOp, since its passed in from the
880 // caller.
881 rewriter.replaceAllUsesWith(candidateSliceOp,
882 tileAndFuseResult->tiledValues[0]);
883 rewriter.eraseOp(op: clonedCandidateSliceOp);
884 rewriter.eraseOp(op: clonedProducerOp);
885
886 // 3. If the slice is for a destination operand, for example,
887 //
888 // ```mlir
889 // %0 = linalg.init
890 // %1 = linalg.fill .. outs(%0 : )
891 // %2 = scf.for .. iter_args(%arg0 = %1) {
892 // %3 = scf.for .. iter_args(%arg1 = %arg0) {
893 // %4 = tensor.extract_slice %arg1 [..]
894 // .. = linalg.matmul .. outs(%4 : )
895 // }
896 // }
897 // ```
898 //
899 // the IR is currently
900 //
901 // ```
902 // %0 = linalg.init
903 // %1 = linalg.fill
904 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
905 // %3 = scf.for .. iter_args(%arg1 = %arg0) {
906 // %4 = tensor.extract_slice %arg1[..]
907 // %5 = linalg.fill .. outs(%4 : )
908 // .. = linalg.matmul .. outs(%5 : )
909 // }
910 // }
911 // ```
912 //
913 // The untiled `linalg.fill` is still used as the `init_value` since it
914 // was originally a destination operand of the untiled `linalg.matmul`.
915 // When fusing an operand that is a destination operand, the iter_arg of
916 // the outer most loop should be changed to use the destination of the
917 // fused operation. With this the IR will be.
918 //
919 // ```
920 // %0 = linalg.init
921 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
922 // %2 = scf.for .. iter_args(%arg1 = %arg0) {
923 // %3 = tensor.extract_slice %arg1[..]
924 // %4 = linalg.fill .. outs(%3 : )
925 // .. = linalg.matmul .. outs(%4 : )
926 // }
927 // }
928 // ```
929 if (destinationInitArg &&
930 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
931 loops.front()
932 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
933 .set(origDestinationTensors[resultNumber]);
934 }
935 return scf::SCFFuseProducerOfSliceResult{fusableProducer,
936 tileAndFuseResult->tiledValues[0],
937 tileAndFuseResult->tiledOps};
938}
939
940/// Reconstruct the fused producer from within the tiled-and-fused code.
941LogicalResult mlir::scf::yieldReplacementForFusedProducer(
942 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
943 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
944 MutableArrayRef<LoopLikeOpInterface> loops) {
945 if (loops.empty())
946 return success();
947
948 OpResult fusableProducer = fusedProducerInfo.origProducer;
949 Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
950 FailureOr<Value> initValue = tensor::getOrCreateDestination(
951 b&: rewriter, loc: fusableProducer.getOwner()->getLoc(), opResult: fusableProducer);
952 if (succeeded(result: initValue)) {
953
954 YieldTiledValuesFn newYieldValuesFn =
955 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
956 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
957 SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
958 SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
959 -> LogicalResult {
960 OpBuilder::InsertionGuard g(innerRewriter);
961 if (auto tiledDestStyleOp =
962 tiledAndFusedProducer
963 .getDefiningOp<DestinationStyleOpInterface>()) {
964 rewriter.setInsertionPoint(tiledDestStyleOp);
965 Value newRegionArg = newRegionIterArgs.back();
966 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
967 sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
968 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
969 unsigned resultNumber = fusableProducer.getResultNumber();
970 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
971 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
972 });
973 }
974 Block *block = rewriter.getInsertionPoint()->getBlock();
975 rewriter.setInsertionPoint(block->getTerminator());
976 tiledResult.push_back(Elt: fusedProducerInfo.tiledAndFusedProducer);
977 tiledOffset.emplace_back(sliceOp.getMixedOffsets());
978 tiledSizes.emplace_back(sliceOp.getMixedSizes());
979 return success();
980 };
981
982 return addInitOperandsToLoopNest(rewriter, loops,
983 SmallVector<Value>{initValue.value()},
984 newYieldValuesFn);
985 }
986 return success();
987}
988
989/// Implementation of tile consumer and fuse producer greedily.
990FailureOr<scf::SCFTileAndFuseResult>
991mlir::scf::tileConsumerAndFuseProducersUsingSCF(
992 RewriterBase &rewriter, TilingInterface consumer,
993 const scf::SCFTileAndFuseOptions &options) {
994 // This transformation is only valid for ops that return values (i.e. not
995 // valid to use with operations that have memref operands).
996 if (!consumer->getNumResults()) {
997 return rewriter.notifyMatchFailure(
998 consumer, "invalid pattern for op with no results");
999 }
1000
1001 // 1. First tile the consumer.
1002 SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1003 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1004
1005 FailureOr<scf::SCFTilingResult> tilingResult =
1006 tileUsingSCF(rewriter, consumer, options.tilingOptions);
1007
1008 if (failed(result: tilingResult))
1009 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1010 for (auto *tiledOp : tilingResult->tiledOps)
1011 tiledAndFusedOps.insert(tiledOp);
1012
1013 // If there are no loops generated, fusion is immaterial.
1014 auto &loops = tilingResult->loops;
1015 if (loops.empty()) {
1016 DenseMap<Value, Value> replacements;
1017 for (auto [origVal, replacement] :
1018 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1019 replacements[origVal] = replacement;
1020 }
1021 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1022 replacements};
1023 }
1024
1025 // To keep track of replacements for now just record the map from the original
1026 // untiled value to the result number of the for loop. Since the loop gets
1027 // potentially replaced during fusion, keeping the value directly wont work.
1028 DenseMap<Value, size_t> origValToResultNumber;
1029 for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1030 origValToResultNumber[result] = index;
1031 }
1032
1033 // 2. Typically, the operands of the tiled operation are slices of the
1034 // operands of the untiled operation. These are expressed in IR using
1035 // `tensor.extract_slice` operations with source being the operands of the
1036 // untiled operation. Create a worklist of these `tensor.extract_slice`
1037 // operations. If the producers of the source of the `tensor.extract_slice`
1038 // can be tiled such that the tiled value is generated in-place, that
1039 // effectively tiles + fuses the operations.
1040 auto addCandidateSlices = [](Operation *fusedOp,
1041 std::deque<tensor::ExtractSliceOp> &candidates) {
1042 for (Value operand : fusedOp->getOperands())
1043 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
1044 candidates.push_back(sliceOp);
1045 };
1046
1047 std::deque<tensor::ExtractSliceOp> candidates;
1048 addCandidateSlices(tiledAndFusedOps.back(), candidates);
1049 OpBuilder::InsertionGuard g(rewriter);
1050 while (!candidates.empty()) {
1051 // Traverse the slices in BFS fashion.
1052 tensor::ExtractSliceOp candidateSliceOp = candidates.front();
1053 candidates.pop_front();
1054
1055 // Find the original producer of the slice.
1056 auto [fusableProducer, destinationInitArg] =
1057 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1058 loops);
1059 if (!fusableProducer)
1060 continue;
1061
1062 auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
1063 candidateSliceOp, fusableProducer, destinationInitArg.has_value());
1064 if (!fuseSlice)
1065 continue;
1066
1067 // The operands of the fused producer might themselved be slices of
1068 // values produced by operations that implement the `TilingInterface`.
1069 // Add these operations to the worklist.
1070 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1071 tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
1072 if (!fusedResult)
1073 continue;
1074
1075 if (yieldReplacement) {
1076 if (failed(yieldReplacementForFusedProducer(
1077 rewriter, candidateSliceOp, fusedResult.value(), loops))) {
1078 return rewriter.notifyMatchFailure(
1079 fusableProducer.getOwner(), "failed to replacement value for this "
1080 "oepration from within the tiled loop");
1081 }
1082 origValToResultNumber[fusableProducer] =
1083 loops.front()->getNumResults() - 1;
1084 }
1085
1086 if (Operation *tiledAndFusedOp =
1087 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1088 fusedProducers.insert(X: fusedResult->origProducer.getDefiningOp());
1089 tiledAndFusedOps.insert(X: tiledAndFusedOp);
1090 addCandidateSlices(tiledAndFusedOp, candidates);
1091 }
1092 }
1093
1094 DenseMap<Value, Value> replacements;
1095 for (auto [origVal, resultNumber] : origValToResultNumber) {
1096 replacements[origVal] = loops.front()->getResult(resultNumber);
1097 }
1098
1099 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1100 replacements};
1101}
1102
1103//===----------------------------------------------------------------------===//
1104// lowerToLoopsUsingSCFForOp implementation.
1105//===----------------------------------------------------------------------===//
1106
1107FailureOr<SmallVector<scf::ForOp>>
1108mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
1109 TilingInterface op) {
1110 // TODO: Handle cases where the op has results if needed.
1111 if (op->getNumResults() > 0) {
1112 return rewriter.notifyMatchFailure(
1113 op, "unable to lower to loops operations with return values");
1114 }
1115
1116 SmallVector<Range> domain = op.getIterationDomain(rewriter);
1117 SmallVector<Value> ivs;
1118 SmallVector<scf::ForOp> loops;
1119 Location loc = op.getLoc();
1120 for (auto loopRange : domain) {
1121 Value offsetVal =
1122 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1123 Value sizeVal =
1124 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1125 Value strideVal =
1126 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1127 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1128 strideVal, ValueRange{});
1129 loops.push_back(loop);
1130 ivs.push_back(loop.getInductionVar());
1131 rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1132 }
1133 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1134 return failure();
1135 }
1136 return loops;
1137}
1138

source code of mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp