1 | //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the tiling using TilingInterface. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
14 | |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
20 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
21 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
22 | #include "mlir/IR/Matchers.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
25 | #include "mlir/Interfaces/TilingInterface.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | #include "llvm/Support/Debug.h" |
28 | #include <optional> |
29 | |
30 | #define DEBUG_TYPE "tile-using-interface" |
31 | |
32 | using namespace mlir; |
33 | |
34 | scf::SCFTilingOptions & |
35 | scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { |
36 | assert(!tileSizeComputationFunction && "tile sizes already set" ); |
37 | auto tileSizes = llvm::to_vector(Range&: ts); |
38 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { |
39 | return tileSizes; |
40 | }; |
41 | return *this; |
42 | } |
43 | |
44 | /// Helper method to adjust the interchange vector to match the iteration |
45 | /// domain. |
46 | static SmallVector<int64_t> |
47 | fillInterchangeVector(ArrayRef<int64_t> interchangeVector, |
48 | size_t iterationDomainSize) { |
49 | SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector); |
50 | if (filledVector.size() < iterationDomainSize) { |
51 | auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize); |
52 | filledVector.append(in_start: range.begin(), in_end: range.end()); |
53 | } |
54 | if (filledVector.size() > iterationDomainSize) |
55 | filledVector.resize(N: iterationDomainSize); |
56 | return filledVector; |
57 | } |
58 | |
59 | //===----------------------------------------------------------------------===// |
60 | // tileUsingSCF implementation. |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | // Check if `stride` evenly divides the trip count `size - offset`. |
64 | static bool tileDividesIterationDomain(Range loopRange) { |
65 | std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset); |
66 | if (!offsetAsInt) |
67 | return false; |
68 | std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size); |
69 | if (!sizeAsInt) |
70 | return false; |
71 | std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride); |
72 | if (!strideAsInt) |
73 | return false; |
74 | return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); |
75 | } |
76 | |
77 | /// Returns the bounded tile size given the current `iv`, `loopRange` and |
78 | /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. |
79 | static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, |
80 | Range loopRange, Value iv, |
81 | OpFoldResult tileSize) { |
82 | std::optional<int64_t> ts = getConstantIntValue(ofr: tileSize); |
83 | if (ts && ts.value() == 1) |
84 | return tileSize; |
85 | |
86 | if (tileDividesIterationDomain( |
87 | loopRange: Range{.offset: loopRange.offset, .size: loopRange.size, .stride: tileSize})) |
88 | return tileSize; |
89 | |
90 | // The tile size to use (to avoid out of bounds access) is minimum of |
91 | // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled |
92 | // loop. |
93 | AffineExpr s0, s1, d0; |
94 | bindDims(ctx: b.getContext(), exprs&: d0); |
95 | bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1); |
96 | AffineMap minMap = AffineMap::get(dimCount: 1, symbolCount: 2, results: {s0, s1 - d0}, context: b.getContext()); |
97 | Value size = getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange.size); |
98 | return affine::makeComposedFoldedAffineMin( |
99 | b, loc, map: minMap, operands: SmallVector<OpFoldResult>{iv, tileSize, size}); |
100 | } |
101 | |
102 | /// A function that allows returning additional yielded values during |
103 | /// `yieldTiledValuesAndReplace`. |
104 | /// - `ivs` induction variable for the loop. |
105 | /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. |
106 | /// - `tiledValues` the tiled values to return. Must be of same size as |
107 | /// `newbbArgs`, each element of this array is inserted into the corresponding |
108 | /// element in `newbbArgs`. |
109 | /// - `resultOffsets` is of the same size as `tiledValues` and represents |
110 | /// the offsets to use when inserting corresponding element from `tiledValues` |
111 | /// into the element from `newBbArgs`. |
112 | /// - `resultSizes` is of the same size as `tiledValues` and represents |
113 | /// the size of the corresponding element from `tiledValues` inserted into |
114 | /// the element from `newBbArgs`. |
115 | /// In case the method needs to return `failure()` the method is expected |
116 | /// to clean up any inserted operations. |
117 | using YieldTiledValuesFn = std::function<LogicalResult( |
118 | RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, |
119 | SmallVector<Value> &tiledValues, |
120 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
121 | SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; |
122 | |
123 | /// Clones the operation and updates the destination if the operation |
124 | /// implements the `DestinationStyleOpInterface`. |
125 | static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, |
126 | Operation *op, |
127 | ValueRange newDestArgs) { |
128 | Operation *clonedOp = rewriter.clone(op&: *op); |
129 | if (newDestArgs.empty()) |
130 | return clonedOp; |
131 | if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) |
132 | destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); |
133 | return clonedOp; |
134 | } |
135 | |
136 | /// Generate the tile-loop nest using `scf.for` operation. |
137 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
138 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
139 | /// - `destinationTensors` are the init values to use for the outer most loop. |
140 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
141 | /// most |
142 | /// loop. |
143 | /// - `loops` is an in-out parameter into which the generated loops are |
144 | /// populated. |
145 | static LogicalResult generateLoopNestUsingForOp( |
146 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
147 | ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, |
148 | YieldTiledValuesFn yieldTiledValuesFn, |
149 | SmallVector<LoopLikeOpInterface> &loops) { |
150 | assert(!loopRanges.empty() && "unexpected empty loop ranges" ); |
151 | assert(loopRanges.size() == tileSizes.size() && |
152 | "expected as many tile sizes as loop ranges" ); |
153 | OpBuilder::InsertionGuard guard(rewriter); |
154 | SmallVector<Value> ivs; |
155 | |
156 | for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) { |
157 | // No loops if tile size is zero. Set offset and size to the loop |
158 | // offset and size. |
159 | if (isConstantIntValue(ofr: tileSize, value: 0)) |
160 | continue; |
161 | |
162 | Value lb = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.offset); |
163 | Value ub = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.size); |
164 | Value step = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: tileSize); |
165 | auto loop = |
166 | rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, |
167 | [](OpBuilder &bodyBuilder, Location bodyLoc, |
168 | Value iv, ValueRange /*iterArgs*/) {}); |
169 | loops.push_back(loop); |
170 | ivs.push_back(Elt: loop.getInductionVar()); |
171 | rewriter.setInsertionPointToEnd(loop.getBody()); |
172 | destinationTensors = loop.getRegionIterArgs(); |
173 | } |
174 | |
175 | SmallVector<Value> tiledResults; |
176 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
177 | if (failed(result: yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, |
178 | tiledResults, resultOffsets, resultSizes))) { |
179 | return rewriter.notifyMatchFailure( |
180 | arg&: loc, msg: "failed to generate inner tile loop body" ); |
181 | } |
182 | if (loops.empty()) |
183 | return success(); |
184 | |
185 | // 6. Yield all the results of the tiled operation. |
186 | SmallVector<Value> yieldedValues; |
187 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
188 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
189 | args&: resultSizes)) { |
190 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
191 | rewriter.getIndexAttr(1)); |
192 | auto insertSlice = rewriter.create<tensor::InsertSliceOp>( |
193 | loc, tiledValue, destinationTensor, resultOffset, resultSize, |
194 | resultStride); |
195 | yieldedValues.push_back(Elt: insertSlice); |
196 | } |
197 | rewriter.create<scf::YieldOp>(loc, yieldedValues); |
198 | |
199 | // Add the scf.yield operations for all the outer loops. |
200 | for (auto [outerLoop, innerLoop] : |
201 | llvm::zip_equal(MutableArrayRef(loops).drop_back(), |
202 | MutableArrayRef(loops).drop_front())) { |
203 | rewriter.setInsertionPointToEnd( |
204 | cast<scf::ForOp>(outerLoop.getOperation()).getBody()); |
205 | rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); |
206 | } |
207 | return success(); |
208 | } |
209 | |
210 | /// Generate the tile-loop nest using `scf.forall` operation. |
211 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
212 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
213 | /// - `destinationTensors` are the init values to use for the outer most loop. |
214 | /// - `mappingVector` is the mapping attributes to use for loop construction. |
215 | /// Can be empty. |
216 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
217 | /// most |
218 | /// loop. |
219 | /// - `loops` is an in-out parameter into which the generated loops are |
220 | /// populated. |
221 | static LogicalResult generateLoopNestUsingForallOp( |
222 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
223 | ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector, |
224 | ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, |
225 | SmallVector<LoopLikeOpInterface> &loops) { |
226 | SmallVector<OpFoldResult> lbs, ubs, steps; |
227 | assert(!loopRanges.empty() && "unexpected empty loop ranges" ); |
228 | assert(loopRanges.size() == tileSizes.size() && |
229 | "expected as many tile sizes as loop ranges" ); |
230 | OpBuilder::InsertionGuard guard(rewriter); |
231 | SmallVector<OpFoldResult> offsets(loopRanges.size()), |
232 | sizes(loopRanges.size()); |
233 | |
234 | for (auto [tileSize, loopRange] : llvm::zip_equal(t&: tileSizes, u&: loopRanges)) { |
235 | if (isConstantIntValue(ofr: tileSize, value: 0)) |
236 | continue; |
237 | lbs.push_back(Elt: loopRange.offset); |
238 | ubs.push_back(Elt: loopRange.size); |
239 | steps.push_back(Elt: tileSize); |
240 | } |
241 | assert(!lbs.empty() && "Expected at least one loop range" ); |
242 | |
243 | std::optional<ArrayAttr> mappingAttr; |
244 | if (!mappingVector.empty()) |
245 | mappingAttr = rewriter.getArrayAttr(mappingVector); |
246 | |
247 | auto forallOp = rewriter.create<scf::ForallOp>( |
248 | loc, lbs, ubs, steps, destinationTensors, mappingAttr); |
249 | loops.push_back(forallOp); |
250 | |
251 | rewriter.setInsertionPoint(forallOp.getTerminator()); |
252 | destinationTensors = forallOp.getRegionOutArgs(); |
253 | |
254 | SmallVector<Value> tiledResults; |
255 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
256 | if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), |
257 | destinationTensors, tiledResults, resultOffsets, |
258 | resultSizes))) |
259 | return rewriter.notifyMatchFailure(arg&: loc, msg: "failed to generate loop body" ); |
260 | |
261 | rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); |
262 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
263 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
264 | args&: resultSizes)) { |
265 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
266 | rewriter.getIndexAttr(1)); |
267 | |
268 | rewriter.create<tensor::ParallelInsertSliceOp>( |
269 | loc, tiledValue, destinationTensor, resultOffset, resultSize, |
270 | resultStride); |
271 | } |
272 | return success(); |
273 | } |
274 | |
275 | /// Generate the tile-loop nest using the loop construct specifed in `options`. |
276 | /// - `options`: Tiling options specified. |
277 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
278 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
279 | /// - `destinationTensors` are the init values to use for the outer most loop. |
280 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
281 | /// most |
282 | /// loop. |
283 | /// - `loops` is an in-out parameter into which the generated loops are |
284 | /// populated. |
285 | static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, |
286 | const scf::SCFTilingOptions &options, |
287 | ArrayRef<Range> loopRanges, |
288 | ArrayRef<OpFoldResult> tileSizes, |
289 | ValueRange destinationTensors, |
290 | YieldTiledValuesFn tiledBodyFn, |
291 | SmallVector<LoopLikeOpInterface> &loops) { |
292 | // If the tile sizes are all zero, no loops are generated. Just call the |
293 | // callback function to handle untiled case. |
294 | if (llvm::all_of(Range&: tileSizes, P: isZeroIndex)) { |
295 | SmallVector<Value> tiledResults; |
296 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
297 | return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, |
298 | tiledResults, resultOffsets, resultSizes); |
299 | } |
300 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { |
301 | return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, |
302 | destinationTensors, tiledBodyFn, loops); |
303 | } |
304 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { |
305 | return generateLoopNestUsingForallOp( |
306 | rewriter, loc, loopRanges, tileSizes, options.mappingVector, |
307 | destinationTensors, tiledBodyFn, loops); |
308 | } |
309 | return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type" ); |
310 | } |
311 | |
312 | /// Append the specified additional `newInitOperands` operands to the |
313 | /// loops existing `init` operands (or similar), and replace `loopOp` with |
314 | /// the new loop that has the additional init operands. The loop body of |
315 | /// this loop is moved over to the new loop. `yieldTiledValuesFn` |
316 | /// is called to get the new tiled values returned, and the offset |
317 | /// and sizes at which the tiled value is inserted into the |
318 | /// new region iter_args that correspond to the newly added init operands. |
319 | template <typename LoopType> |
320 | FailureOr<LoopLikeOpInterface> |
321 | yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, |
322 | ValueRange newInitOperands, |
323 | YieldTiledValuesFn yieldTiledValuesFn) { |
324 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type" ); |
325 | } |
326 | |
327 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. |
328 | template <> |
329 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( |
330 | scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
331 | YieldTiledValuesFn yieldTiledValuesFn) { |
332 | OpBuilder::InsertionGuard g(rewriter); |
333 | Location loc = loopOp.getLoc(); |
334 | rewriter.setInsertionPoint(loopOp); |
335 | |
336 | auto inits = llvm::to_vector(loopOp.getInitArgs()); |
337 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
338 | auto newLoop = rewriter.create<scf::ForOp>( |
339 | loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), |
340 | inits, [](OpBuilder &, Location, Value, ValueRange) {}); |
341 | |
342 | // Move the loop body to the new op. |
343 | Block *loopBody = loopOp.getBody(); |
344 | Block *newLoopBody = newLoop.getBody(); |
345 | rewriter.mergeBlocks( |
346 | loopBody, newLoopBody, |
347 | newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
348 | |
349 | auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); |
350 | rewriter.setInsertionPoint(yieldOp); |
351 | |
352 | SmallVector<Value> tiledValues; |
353 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
354 | ValueRange newRegionIterArgs = |
355 | newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
356 | if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), |
357 | newRegionIterArgs, tiledValues, resultOffsets, |
358 | resultSizes))) { |
359 | rewriter.eraseOp(newLoop); |
360 | return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values" ); |
361 | } |
362 | |
363 | SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); |
364 | for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : |
365 | llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, |
366 | resultSizes)) { |
367 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
368 | rewriter.getIndexAttr(1)); |
369 | Value insert = rewriter.create<tensor::InsertSliceOp>( |
370 | yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, |
371 | resultStride); |
372 | newYieldValues.push_back(insert); |
373 | } |
374 | |
375 | rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); |
376 | rewriter.replaceOp(loopOp, |
377 | newLoop->getResults().take_front(loopOp.getNumResults())); |
378 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
379 | } |
380 | |
381 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` |
382 | template <> |
383 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( |
384 | scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
385 | YieldTiledValuesFn yieldTiledValuesFn) { |
386 | OpBuilder::InsertionGuard g(rewriter); |
387 | Location loc = loopOp.getLoc(); |
388 | rewriter.setInsertionPoint(loopOp); |
389 | auto inits = llvm::to_vector(loopOp.getOutputs()); |
390 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
391 | auto newLoop = rewriter.create<scf::ForallOp>( |
392 | loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), |
393 | loopOp.getMixedStep(), inits, loopOp.getMapping(), |
394 | [](OpBuilder &, Location, ValueRange) {}); |
395 | |
396 | // Move the region of the current block to the newly created op. |
397 | Block *loopBody = loopOp.getBody(); |
398 | Block *newLoopBody = newLoop.getBody(); |
399 | rewriter.mergeBlocks( |
400 | loopBody, newLoopBody, |
401 | newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
402 | |
403 | auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); |
404 | rewriter.setInsertionPoint(terminator); |
405 | SmallVector<Value> tiledValues; |
406 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
407 | ValueRange regionIterArgs = |
408 | newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
409 | if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), |
410 | regionIterArgs, tiledValues, resultOffsets, |
411 | resultSizes))) { |
412 | rewriter.eraseOp(newLoop); |
413 | return rewriter.notifyMatchFailure(loopOp, |
414 | "failed to get yielded tiled values" ); |
415 | } |
416 | |
417 | // Update the terminator. |
418 | rewriter.setInsertionPointToEnd(terminator.getBody()); |
419 | |
420 | for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( |
421 | tiledValues, regionIterArgs, resultOffsets, resultSizes)) { |
422 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
423 | rewriter.getIndexAttr(1)); |
424 | rewriter.create<tensor::ParallelInsertSliceOp>( |
425 | terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, |
426 | resultStride); |
427 | } |
428 | |
429 | rewriter.replaceOp(loopOp, |
430 | newLoop->getResults().take_front(loopOp.getNumResults())); |
431 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
432 | } |
433 | |
434 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for |
435 | /// `LoopLikeOpInterface`, that just dispatches to the implementation for each |
436 | /// supported loop type. |
437 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( |
438 | LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, |
439 | ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { |
440 | return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( |
441 | loopLikeOp.getOperation()) |
442 | .Case<scf::ForOp, scf::ForallOp>( |
443 | [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
444 | return yieldTiledValuesAndReplaceLoop( |
445 | loopOp, rewriter, newInitOperands, yieldTiledValuesFn); |
446 | }) |
447 | .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
448 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type" ); |
449 | }); |
450 | } |
451 | |
452 | /// Method to add new init values to a loop nest. Updates `loops` in-place with |
453 | /// new loops that use the `newInitValues`. |
454 | /// The outer-loops are updated to yield the new result values of the inner |
455 | /// loop. For the innermost loop, the call back `getNewYields` is invoked to get |
456 | /// the additional values to yield form the innermost loop. |
457 | static LogicalResult addInitOperandsToLoopNest( |
458 | RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, |
459 | ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { |
460 | SmallVector<scf::ForOp> newLoops; |
461 | if (loops.empty()) |
462 | return success(); |
463 | OpBuilder::InsertionGuard g(rewriter); |
464 | rewriter.setInsertionPoint(loops.front()); |
465 | |
466 | SmallVector<Value> ivs; |
467 | for (auto &loop : loops.drop_back()) { |
468 | rewriter.setInsertionPoint(loop); |
469 | |
470 | // if loops.size() > 1 we assume that scf.for is used for the loops. |
471 | auto forLoop = cast<scf::ForOp>(loop.getOperation()); |
472 | |
473 | // Create a new loop with the new init values for this loop. |
474 | SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); |
475 | newInits.append(newInitValues.begin(), newInitValues.end()); |
476 | auto newLoop = rewriter.create<scf::ForOp>( |
477 | forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), |
478 | forLoop.getStep(), newInits, |
479 | [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); |
480 | |
481 | // Merge the body of the new loop with the body of the old loops. |
482 | SmallVector<Value> sourceBlockArgs; |
483 | sourceBlockArgs.push_back(newLoop.getInductionVar()); |
484 | auto newRegionIterArgs = newLoop.getRegionIterArgs(); |
485 | sourceBlockArgs.append( |
486 | newRegionIterArgs.begin(), |
487 | std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); |
488 | rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); |
489 | rewriter.replaceOp( |
490 | forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); |
491 | loop = newLoop; |
492 | ivs.push_back(newLoop.getInductionVar()); |
493 | newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); |
494 | } |
495 | |
496 | // Update the loop body of the innermost loop to get new yield values. |
497 | LoopLikeOpInterface innerMostLoop = loops.back(); |
498 | FailureOr<LoopLikeOpInterface> newInnerMostLoop = |
499 | yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, |
500 | getNewTiledYieldsFn); |
501 | |
502 | if (failed(newInnerMostLoop)) |
503 | return innerMostLoop.emitOpError("failed to return additional yields" ); |
504 | loops.back() = newInnerMostLoop.value(); |
505 | |
506 | // Make all other loops except the innermost loops yield the values returned |
507 | // by the inner loop. |
508 | for (auto [outerLoop, innerLoop] : |
509 | llvm::zip_equal(loops.drop_back(), loops.drop_front())) { |
510 | // Again assume that all the outer loops are scf.for operations. |
511 | auto outerForLoop = cast<scf::ForOp>(outerLoop); |
512 | auto outerLoopYield = |
513 | cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); |
514 | SmallVector<Value> newYields = |
515 | llvm::to_vector(outerLoopYield.getOperands()); |
516 | ValueRange additionalYields = |
517 | innerLoop->getResults().take_back(newInitValues.size()); |
518 | newYields.append(additionalYields.begin(), additionalYields.end()); |
519 | rewriter.setInsertionPoint(outerLoopYield); |
520 | rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); |
521 | } |
522 | return success(); |
523 | } |
524 | |
525 | /// Implementation of tiling transformation of `op` that implements the |
526 | /// `TilingInterface` using `scf.for` to iterate over the tiles. |
527 | FailureOr<scf::SCFTilingResult> |
528 | mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, |
529 | const scf::SCFTilingOptions &options) { |
530 | OpBuilder::InsertionGuard guard(rewriter); |
531 | rewriter.setInsertionPointAfter(op); |
532 | |
533 | if (!options.tileSizeComputationFunction) { |
534 | return rewriter.notifyMatchFailure( |
535 | op, "missing tile size computation function" ); |
536 | } |
537 | |
538 | // 1. Get the range of the loops that are represented by the operation. |
539 | SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); |
540 | size_t numLoops = iterationDomain.size(); |
541 | |
542 | // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" |
543 | // skips tiling a particular dimension. This convention is significantly |
544 | // simpler to handle instead of adjusting affine maps to account for missing |
545 | // dimensions. |
546 | SmallVector<OpFoldResult> tileSizes = |
547 | options.tileSizeComputationFunction(rewriter, op); |
548 | if (tileSizes.size() < iterationDomain.size()) { |
549 | auto zero = rewriter.getIndexAttr(0); |
550 | tileSizes.append(numLoops - tileSizes.size(), zero); |
551 | } |
552 | |
553 | // 3. If there is an interchange specified, permute the iteration domain and |
554 | // the tile sizes. |
555 | SmallVector<int64_t> interchangeVector; |
556 | if (!options.interchangeVector.empty()) { |
557 | interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector, |
558 | iterationDomainSize: iterationDomain.size()); |
559 | } |
560 | if (!interchangeVector.empty()) { |
561 | if (!isPermutationVector(interchange: interchangeVector)) { |
562 | return rewriter.notifyMatchFailure( |
563 | op, "invalid intechange vector, not a permutation of the entire " |
564 | "iteration space" ); |
565 | } |
566 | |
567 | applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector); |
568 | applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector); |
569 | } |
570 | |
571 | FailureOr<TilingResult> tilingResult; |
572 | // 4. Define the lambda function used later to generate the body of the |
573 | // innermost tiled loop. |
574 | YieldTiledValuesFn innerYieldTiledValuesFn = |
575 | [&](RewriterBase &rewriter, Location loc, ValueRange ivs, |
576 | ValueRange regionIterArgs, SmallVector<Value> &tiledResults, |
577 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
578 | SmallVector<SmallVector<OpFoldResult>> &resultSizes) |
579 | -> LogicalResult { |
580 | // 4a. Compute the `offsets` and `sizes` to use for tiling. |
581 | SmallVector<OpFoldResult> offsets, sizes; |
582 | { |
583 | int materializedLoopNum = 0; |
584 | for (auto [tileSize, loopRange] : |
585 | llvm::zip_equal(tileSizes, iterationDomain)) { |
586 | if (isConstantIntValue(tileSize, 0)) { |
587 | offsets.push_back(loopRange.offset); |
588 | sizes.push_back(loopRange.size); |
589 | continue; |
590 | } |
591 | Value iv = ivs[materializedLoopNum++]; |
592 | offsets.push_back(iv); |
593 | sizes.push_back( |
594 | getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); |
595 | } |
596 | } |
597 | |
598 | // 4b. If interchange was provided, apply inverse of the interchange |
599 | // to get back the offsets/sizes in the order to be specified. |
600 | if (!interchangeVector.empty()) { |
601 | auto inversePermutation = invertPermutationVector(permutation: interchangeVector); |
602 | applyPermutationToVector(inVec&: offsets, permutation: inversePermutation); |
603 | applyPermutationToVector(inVec&: sizes, permutation: inversePermutation); |
604 | } |
605 | |
606 | // 5. Generate the tiled implementation within the inner most loop. |
607 | |
608 | // 5a. Clone the operation within the loop body. |
609 | auto clonedOp = cast<TilingInterface>( |
610 | cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); |
611 | |
612 | // 5b. Early return cloned op if tiling is not happening. We can not return |
613 | // the original op because it could lead to |
614 | // `rewriter.replaceOp(op, op->getResults())` and users would get crash. |
615 | if (llvm::all_of(Range&: tileSizes, P: isZeroIndex)) { |
616 | tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); |
617 | tilingResult = |
618 | TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()}; |
619 | return success(); |
620 | } |
621 | |
622 | // 5c. Tile the cloned operation. |
623 | tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes); |
624 | if (failed(result: tilingResult)) { |
625 | rewriter.eraseOp(op: clonedOp); |
626 | return op.emitOpError("faild to tile operation" ); |
627 | } |
628 | |
629 | // 5d. Delete the cloned operation. |
630 | rewriter.eraseOp(op: clonedOp); |
631 | |
632 | // 5e. Compute the offsets at which the result values are to be inserted |
633 | // back into its destinations. |
634 | for (auto [index, tiledValue] : |
635 | llvm::enumerate(First&: tilingResult->tiledValues)) { |
636 | tiledResults.push_back(Elt: tiledValue); |
637 | SmallVector<OpFoldResult> resultOffset, resultSize; |
638 | if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes, |
639 | resultOffset, resultSize))) { |
640 | for (auto op : tilingResult->tiledOps) { |
641 | rewriter.eraseOp(op); |
642 | } |
643 | return rewriter.notifyMatchFailure( |
644 | op, "failed to get slice of result produced" ); |
645 | } |
646 | resultOffsets.emplace_back(Args: std::move(resultOffset)); |
647 | resultSizes.emplace_back(Args: std::move(resultSize)); |
648 | } |
649 | |
650 | return success(); |
651 | }; |
652 | |
653 | // 6. Find the destination tensors to use for the operation. |
654 | SmallVector<Value> destinationTensors; |
655 | if (failed(tensor::getOrCreateDestinations(b&: rewriter, loc: op.getLoc(), op: op, |
656 | result&: destinationTensors))) { |
657 | return rewriter.notifyMatchFailure(op, |
658 | "unable to create destination tensors" ); |
659 | } |
660 | |
661 | // 7. Generate the tiled loops nest using the callback defined above. |
662 | SmallVector<LoopLikeOpInterface> loops; |
663 | if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, |
664 | tileSizes, destinationTensors, |
665 | innerYieldTiledValuesFn, loops))) |
666 | return op.emitOpError("failed to generate tiling loops" ); |
667 | assert(succeeded(tilingResult) && |
668 | "expected tiling result to be computed after loop generation" ); |
669 | |
670 | // If loops are empty, the tiled op is used as the replacement for the untiled |
671 | // op. |
672 | if (loops.empty()) { |
673 | return scf::SCFTilingResult{tilingResult->tiledOps, loops, |
674 | tilingResult->tiledValues}; |
675 | } |
676 | |
677 | SmallVector<Value> replacements = llvm::map_to_vector( |
678 | loops.front()->getResults(), [](OpResult r) -> Value { return r; }); |
679 | return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements}; |
680 | } |
681 | |
682 | FailureOr<scf::SCFReductionTilingResult> |
683 | mlir::scf::tileReductionUsingScf(RewriterBase &b, |
684 | PartialReductionOpInterface op, |
685 | ArrayRef<OpFoldResult> tileSizes) { |
686 | Location loc = op.getLoc(); |
687 | // Ops implementing PartialReductionOpInterface are expected to implement |
688 | // TilingInterface. |
689 | auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); |
690 | SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); |
691 | auto tileSizesVector = llvm::to_vector(Range&: tileSizes); |
692 | if (tileSizesVector.size() < iterationDomain.size()) { |
693 | auto zero = b.getIndexAttr(0); |
694 | tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(), |
695 | zero); |
696 | } |
697 | if (op->getNumResults() != 1) |
698 | return b.notifyMatchFailure( |
699 | op, "don't support ops with multiple results for now" ); |
700 | SmallVector<utils::IteratorType> iterators = |
701 | tilingInterfaceOp.getLoopIteratorTypes(); |
702 | |
703 | SmallVector<int> reductionDims; |
704 | for (auto [idx, iteratorType] : |
705 | llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { |
706 | if (iteratorType == utils::IteratorType::reduction) |
707 | reductionDims.push_back(idx); |
708 | } |
709 | |
710 | // 2. create the inital tensor value. |
711 | FailureOr<Operation *> identityTensor = |
712 | op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector, |
713 | reductionDims); |
714 | if (failed(result: identityTensor)) |
715 | return b.notifyMatchFailure(op, |
716 | "cannot create a tensor of identity value." ); |
717 | |
718 | // 3. Define the callback to use for generating the inner most tile loop body. |
719 | Operation *parallelOp = nullptr; |
720 | auto innerYieldTiledValuesFn = |
721 | [&](RewriterBase &rewriter, Location loc, ValueRange ivs, |
722 | ValueRange regionIterArgs, SmallVector<Value> &tiledResult, |
723 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
724 | SmallVector<SmallVector<OpFoldResult>> &resultSizes) |
725 | -> LogicalResult { |
726 | SmallVector<OpFoldResult> offsets, sizes; |
727 | { |
728 | int materializedLoopNum = 0; |
729 | for (auto [tileSize, loopRange] : |
730 | llvm::zip_equal(tileSizesVector, iterationDomain)) { |
731 | if (isConstantIntValue(tileSize, 0)) { |
732 | offsets.push_back(loopRange.offset); |
733 | sizes.push_back(loopRange.size); |
734 | continue; |
735 | } |
736 | Value iv = ivs[materializedLoopNum++]; |
737 | offsets.push_back(iv); |
738 | sizes.push_back( |
739 | getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize)); |
740 | } |
741 | } |
742 | |
743 | // 4a. Clone the operation. |
744 | auto clonedOp = cast<PartialReductionOpInterface>( |
745 | cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); |
746 | |
747 | // 4b. Tile the cloned operation. |
748 | parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs, |
749 | offsets, sizes, reductionDims); |
750 | // 4c. Delete the cloned operation. |
751 | b.eraseOp(op: clonedOp); |
752 | |
753 | tiledResult.append(in_start: parallelOp->result_begin(), in_end: parallelOp->result_end()); |
754 | // 4d. Compute the offsets and sizes needed to insert the result of the |
755 | // tiled value back into destination before yielding the destination. |
756 | SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); |
757 | resultOffsets.emplace_back(Args: std::move(outOffsets)); |
758 | |
759 | SmallVector<OpFoldResult> outSizes; |
760 | for (size_t i = 0; i < offsets.size(); i++) { |
761 | outSizes.push_back( |
762 | Elt: tensor::getMixedSize(builder&: b, loc, value: parallelOp->getResult(idx: 0), dim: i)); |
763 | } |
764 | resultSizes.emplace_back(Args: std::move(outSizes)); |
765 | return success(); |
766 | }; |
767 | |
768 | // 5. Generate the tiled implementation using the destination tensors. |
769 | SmallVector<Value> destinationTensors = |
770 | llvm::map_to_vector(C: identityTensor.value()->getResults(), |
771 | F: [](OpResult res) -> Value { return res; }); |
772 | |
773 | SmallVector<LoopLikeOpInterface> loops; |
774 | scf::SCFTilingOptions options; |
775 | options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); |
776 | if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector, |
777 | destinationTensors, innerYieldTiledValuesFn, |
778 | loops))) |
779 | return b.notifyMatchFailure(op, "failed to tile for parallel reduction" ); |
780 | |
781 | SmallVector<Value> replacements = llvm::map_to_vector( |
782 | loops.front()->getResults(), [](OpResult r) -> Value { return r; }); |
783 | |
784 | // 5. Apply the merge reduction to combine all the partial values. |
785 | b.setInsertionPointAfter(*loops.begin()); |
786 | Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); |
787 | b.replaceOp(op, mergeOp->getResults()); |
788 | |
789 | SCFReductionTilingResult results; |
790 | results.initialOp = *identityTensor; |
791 | results.loops = loops; |
792 | results.parallelTiledOp = parallelOp; |
793 | results.mergeOp = mergeOp; |
794 | return results; |
795 | } |
796 | |
797 | //===----------------------------------------------------------------------===// |
798 | // tileConsumerAndFuseProducersUsingSCF implementation. |
799 | //===----------------------------------------------------------------------===// |
800 | |
801 | /// Return the untiled producer whose slice is used in a tiled consumer. The |
802 | /// method traverses the tile loop nest (`loops`) if needed, and returns the |
803 | /// `iter_args` of the outer most that is encountered. Traversing the iter_args |
804 | /// indicates that this is a destination operand of the consumer. If there was |
805 | /// no loop traversal needed, the second value of the returned tuple is empty. |
806 | static std::tuple<OpResult, std::optional<OpOperand *>> |
807 | getUntiledProducerFromSliceSource(OpOperand *source, |
808 | ArrayRef<LoopLikeOpInterface> loops) { |
809 | std::optional<OpOperand *> destinationIterArg; |
810 | auto loopIt = loops.rbegin(); |
811 | while (auto iterArg = dyn_cast<BlockArgument>(Val: source->get())) { |
812 | auto loop = *loopIt; |
813 | if (iterArg.getOwner()->getParentOp() != loop) |
814 | break; |
815 | source = loop.getTiedLoopInit(iterArg); |
816 | loopIt++; |
817 | } |
818 | if (loopIt == loops.rend()) |
819 | destinationIterArg = source; |
820 | return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg}; |
821 | } |
822 | |
823 | /// Implementation of fusing producer of a single slice by computing the |
824 | /// slice of the producer in-place. |
825 | std::optional<scf::SCFFuseProducerOfSliceResult> |
826 | mlir::scf::tileAndFuseProducerOfSlice( |
827 | RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, |
828 | MutableArrayRef<LoopLikeOpInterface> loops) { |
829 | // 1. Get the producer of the source (potentially walking through |
830 | // `iter_args` of nested `scf.for`) |
831 | auto [fusableProducer, destinationInitArg] = |
832 | getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), |
833 | loops); |
834 | if (!fusableProducer) |
835 | return std::nullopt; |
836 | unsigned resultNumber = fusableProducer.getResultNumber(); |
837 | |
838 | OpBuilder::InsertionGuard g(rewriter); |
839 | rewriter.setInsertionPoint(candidateSliceOp); |
840 | |
841 | // 2. Clone the fused producer |
842 | // 2a. Compute the destination operands to use for the cloned operation. |
843 | SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; |
844 | Operation *fusableProducerOp = fusableProducer.getOwner(); |
845 | if (isa<DestinationStyleOpInterface>(fusableProducerOp) && |
846 | failed(tensor::getOrCreateDestinations( |
847 | rewriter, fusableProducerOp->getLoc(), fusableProducerOp, |
848 | origDestinationTensors))) |
849 | return std::nullopt; |
850 | |
851 | clonedOpDestinationTensors = origDestinationTensors; |
852 | if (destinationInitArg && |
853 | isa<DestinationStyleOpInterface>(fusableProducerOp)) { |
854 | // 2b. If the producer is also destination style, then to maintain the |
855 | // destination passing style, update the destination of the producer to be |
856 | // the source of the slice. |
857 | clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); |
858 | } |
859 | // 2c. Clone the fused producer. |
860 | Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( |
861 | rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors); |
862 | // 2d. Update the source of the candidateSlice to be the cloned producer. |
863 | // Easier to just clone the slice with different source since replacements |
864 | // and DCE of cloned ops becomes easier |
865 | SmallVector<Value> candidateSliceOpOperands = |
866 | llvm::to_vector(candidateSliceOp->getOperands()); |
867 | candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber); |
868 | tensor::ExtractSliceOp clonedCandidateSliceOp = |
869 | mlir::clone(rewriter, candidateSliceOp, |
870 | candidateSliceOp->getResultTypes(), candidateSliceOpOperands); |
871 | |
872 | // 3. Generate the tiled implementation of the producer of the source |
873 | FailureOr<TilingResult> tileAndFuseResult = |
874 | tensor::replaceExtractSliceWithTiledProducer( |
875 | rewriter, clonedCandidateSliceOp, |
876 | clonedProducerOp->getResult(idx: resultNumber)); |
877 | if (failed(result: tileAndFuseResult)) |
878 | return std::nullopt; |
879 | // Note: Do not delete the candidateSliceOp, since its passed in from the |
880 | // caller. |
881 | rewriter.replaceAllUsesWith(candidateSliceOp, |
882 | tileAndFuseResult->tiledValues[0]); |
883 | rewriter.eraseOp(op: clonedCandidateSliceOp); |
884 | rewriter.eraseOp(op: clonedProducerOp); |
885 | |
886 | // 3. If the slice is for a destination operand, for example, |
887 | // |
888 | // ```mlir |
889 | // %0 = linalg.init |
890 | // %1 = linalg.fill .. outs(%0 : ) |
891 | // %2 = scf.for .. iter_args(%arg0 = %1) { |
892 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
893 | // %4 = tensor.extract_slice %arg1 [..] |
894 | // .. = linalg.matmul .. outs(%4 : ) |
895 | // } |
896 | // } |
897 | // ``` |
898 | // |
899 | // the IR is currently |
900 | // |
901 | // ``` |
902 | // %0 = linalg.init |
903 | // %1 = linalg.fill |
904 | // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { |
905 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
906 | // %4 = tensor.extract_slice %arg1[..] |
907 | // %5 = linalg.fill .. outs(%4 : ) |
908 | // .. = linalg.matmul .. outs(%5 : ) |
909 | // } |
910 | // } |
911 | // ``` |
912 | // |
913 | // The untiled `linalg.fill` is still used as the `init_value` since it |
914 | // was originally a destination operand of the untiled `linalg.matmul`. |
915 | // When fusing an operand that is a destination operand, the iter_arg of |
916 | // the outer most loop should be changed to use the destination of the |
917 | // fused operation. With this the IR will be. |
918 | // |
919 | // ``` |
920 | // %0 = linalg.init |
921 | // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { |
922 | // %2 = scf.for .. iter_args(%arg1 = %arg0) { |
923 | // %3 = tensor.extract_slice %arg1[..] |
924 | // %4 = linalg.fill .. outs(%3 : ) |
925 | // .. = linalg.matmul .. outs(%4 : ) |
926 | // } |
927 | // } |
928 | // ``` |
929 | if (destinationInitArg && |
930 | isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { |
931 | loops.front() |
932 | ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] |
933 | .set(origDestinationTensors[resultNumber]); |
934 | } |
935 | return scf::SCFFuseProducerOfSliceResult{fusableProducer, |
936 | tileAndFuseResult->tiledValues[0], |
937 | tileAndFuseResult->tiledOps}; |
938 | } |
939 | |
940 | /// Reconstruct the fused producer from within the tiled-and-fused code. |
941 | LogicalResult mlir::scf::yieldReplacementForFusedProducer( |
942 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
943 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
944 | MutableArrayRef<LoopLikeOpInterface> loops) { |
945 | if (loops.empty()) |
946 | return success(); |
947 | |
948 | OpResult fusableProducer = fusedProducerInfo.origProducer; |
949 | Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer; |
950 | FailureOr<Value> initValue = tensor::getOrCreateDestination( |
951 | b&: rewriter, loc: fusableProducer.getOwner()->getLoc(), opResult: fusableProducer); |
952 | if (succeeded(result: initValue)) { |
953 | |
954 | YieldTiledValuesFn newYieldValuesFn = |
955 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
956 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
957 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
958 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) |
959 | -> LogicalResult { |
960 | OpBuilder::InsertionGuard g(innerRewriter); |
961 | if (auto tiledDestStyleOp = |
962 | tiledAndFusedProducer |
963 | .getDefiningOp<DestinationStyleOpInterface>()) { |
964 | rewriter.setInsertionPoint(tiledDestStyleOp); |
965 | Value newRegionArg = newRegionIterArgs.back(); |
966 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
967 | sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), |
968 | sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); |
969 | unsigned resultNumber = fusableProducer.getResultNumber(); |
970 | rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { |
971 | tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); |
972 | }); |
973 | } |
974 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
975 | rewriter.setInsertionPoint(block->getTerminator()); |
976 | tiledResult.push_back(Elt: fusedProducerInfo.tiledAndFusedProducer); |
977 | tiledOffset.emplace_back(sliceOp.getMixedOffsets()); |
978 | tiledSizes.emplace_back(sliceOp.getMixedSizes()); |
979 | return success(); |
980 | }; |
981 | |
982 | return addInitOperandsToLoopNest(rewriter, loops, |
983 | SmallVector<Value>{initValue.value()}, |
984 | newYieldValuesFn); |
985 | } |
986 | return success(); |
987 | } |
988 | |
989 | /// Implementation of tile consumer and fuse producer greedily. |
990 | FailureOr<scf::SCFTileAndFuseResult> |
991 | mlir::scf::tileConsumerAndFuseProducersUsingSCF( |
992 | RewriterBase &rewriter, TilingInterface consumer, |
993 | const scf::SCFTileAndFuseOptions &options) { |
994 | // This transformation is only valid for ops that return values (i.e. not |
995 | // valid to use with operations that have memref operands). |
996 | if (!consumer->getNumResults()) { |
997 | return rewriter.notifyMatchFailure( |
998 | consumer, "invalid pattern for op with no results" ); |
999 | } |
1000 | |
1001 | // 1. First tile the consumer. |
1002 | SetVector<Operation *> fusedProducers, tiledAndFusedOps; |
1003 | llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum; |
1004 | |
1005 | FailureOr<scf::SCFTilingResult> tilingResult = |
1006 | tileUsingSCF(rewriter, consumer, options.tilingOptions); |
1007 | |
1008 | if (failed(result: tilingResult)) |
1009 | return rewriter.notifyMatchFailure(consumer, "failed to tile consumer" ); |
1010 | for (auto *tiledOp : tilingResult->tiledOps) |
1011 | tiledAndFusedOps.insert(tiledOp); |
1012 | |
1013 | // If there are no loops generated, fusion is immaterial. |
1014 | auto &loops = tilingResult->loops; |
1015 | if (loops.empty()) { |
1016 | DenseMap<Value, Value> replacements; |
1017 | for (auto [origVal, replacement] : |
1018 | llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { |
1019 | replacements[origVal] = replacement; |
1020 | } |
1021 | return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, |
1022 | replacements}; |
1023 | } |
1024 | |
1025 | // To keep track of replacements for now just record the map from the original |
1026 | // untiled value to the result number of the for loop. Since the loop gets |
1027 | // potentially replaced during fusion, keeping the value directly wont work. |
1028 | DenseMap<Value, size_t> origValToResultNumber; |
1029 | for (auto [index, result] : llvm::enumerate(consumer->getResults())) { |
1030 | origValToResultNumber[result] = index; |
1031 | } |
1032 | |
1033 | // 2. Typically, the operands of the tiled operation are slices of the |
1034 | // operands of the untiled operation. These are expressed in IR using |
1035 | // `tensor.extract_slice` operations with source being the operands of the |
1036 | // untiled operation. Create a worklist of these `tensor.extract_slice` |
1037 | // operations. If the producers of the source of the `tensor.extract_slice` |
1038 | // can be tiled such that the tiled value is generated in-place, that |
1039 | // effectively tiles + fuses the operations. |
1040 | auto addCandidateSlices = [](Operation *fusedOp, |
1041 | std::deque<tensor::ExtractSliceOp> &candidates) { |
1042 | for (Value operand : fusedOp->getOperands()) |
1043 | if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) |
1044 | candidates.push_back(sliceOp); |
1045 | }; |
1046 | |
1047 | std::deque<tensor::ExtractSliceOp> candidates; |
1048 | addCandidateSlices(tiledAndFusedOps.back(), candidates); |
1049 | OpBuilder::InsertionGuard g(rewriter); |
1050 | while (!candidates.empty()) { |
1051 | // Traverse the slices in BFS fashion. |
1052 | tensor::ExtractSliceOp candidateSliceOp = candidates.front(); |
1053 | candidates.pop_front(); |
1054 | |
1055 | // Find the original producer of the slice. |
1056 | auto [fusableProducer, destinationInitArg] = |
1057 | getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), |
1058 | loops); |
1059 | if (!fusableProducer) |
1060 | continue; |
1061 | |
1062 | auto [fuseSlice, yieldReplacement] = options.fusionControlFn( |
1063 | candidateSliceOp, fusableProducer, destinationInitArg.has_value()); |
1064 | if (!fuseSlice) |
1065 | continue; |
1066 | |
1067 | // The operands of the fused producer might themselved be slices of |
1068 | // values produced by operations that implement the `TilingInterface`. |
1069 | // Add these operations to the worklist. |
1070 | std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = |
1071 | tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops); |
1072 | if (!fusedResult) |
1073 | continue; |
1074 | |
1075 | if (yieldReplacement) { |
1076 | if (failed(yieldReplacementForFusedProducer( |
1077 | rewriter, candidateSliceOp, fusedResult.value(), loops))) { |
1078 | return rewriter.notifyMatchFailure( |
1079 | fusableProducer.getOwner(), "failed to replacement value for this " |
1080 | "oepration from within the tiled loop" ); |
1081 | } |
1082 | origValToResultNumber[fusableProducer] = |
1083 | loops.front()->getNumResults() - 1; |
1084 | } |
1085 | |
1086 | if (Operation *tiledAndFusedOp = |
1087 | fusedResult->tiledAndFusedProducer.getDefiningOp()) { |
1088 | fusedProducers.insert(X: fusedResult->origProducer.getDefiningOp()); |
1089 | tiledAndFusedOps.insert(X: tiledAndFusedOp); |
1090 | addCandidateSlices(tiledAndFusedOp, candidates); |
1091 | } |
1092 | } |
1093 | |
1094 | DenseMap<Value, Value> replacements; |
1095 | for (auto [origVal, resultNumber] : origValToResultNumber) { |
1096 | replacements[origVal] = loops.front()->getResult(resultNumber); |
1097 | } |
1098 | |
1099 | return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, |
1100 | replacements}; |
1101 | } |
1102 | |
1103 | //===----------------------------------------------------------------------===// |
1104 | // lowerToLoopsUsingSCFForOp implementation. |
1105 | //===----------------------------------------------------------------------===// |
1106 | |
1107 | FailureOr<SmallVector<scf::ForOp>> |
1108 | mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, |
1109 | TilingInterface op) { |
1110 | // TODO: Handle cases where the op has results if needed. |
1111 | if (op->getNumResults() > 0) { |
1112 | return rewriter.notifyMatchFailure( |
1113 | op, "unable to lower to loops operations with return values" ); |
1114 | } |
1115 | |
1116 | SmallVector<Range> domain = op.getIterationDomain(rewriter); |
1117 | SmallVector<Value> ivs; |
1118 | SmallVector<scf::ForOp> loops; |
1119 | Location loc = op.getLoc(); |
1120 | for (auto loopRange : domain) { |
1121 | Value offsetVal = |
1122 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); |
1123 | Value sizeVal = |
1124 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); |
1125 | Value strideVal = |
1126 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); |
1127 | auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, |
1128 | strideVal, ValueRange{}); |
1129 | loops.push_back(loop); |
1130 | ivs.push_back(loop.getInductionVar()); |
1131 | rewriter.setInsertionPoint(loop.getBody()->getTerminator()); |
1132 | } |
1133 | if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { |
1134 | return failure(); |
1135 | } |
1136 | return loops; |
1137 | } |
1138 | |