1//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/LoopLikeInterface.h"
16#include "mlir/Interfaces/TilingInterface.h"
17#include "mlir/Interfaces/ViewLikeInterface.h"
18#include "mlir/Rewrite/FrozenRewritePatternSet.h"
19
20#include <deque>
21
// Forward declarations of the core MLIR types referenced by this header.
// NOTE(review): the headers included above already provide complete
// definitions of these classes; these declarations are presumably kept only
// to make the header's dependencies explicit.
namespace mlir {
class TilingInterface;
class RewriterBase;
class Operation;
} // namespace mlir
27
28namespace mlir {
29namespace scf {
30
31using SCFTileSizeComputationFunction =
32 std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
33
34/// Options to use to control tiling.
35struct SCFTilingOptions {
36 /// Computation function that returns the tile sizes to use for each loop.
37 /// Returning a tile size of zero implies no tiling for that loop. If the
38 /// size of the returned vector is smaller than the number of loops, the inner
39 /// loops are not tiled. If the size of the returned vector is larger, then
40 /// the vector is truncated to number of loops.
41 SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
42
43 SCFTilingOptions &
44 setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
45 tileSizeComputationFunction = std::move(fun);
46 return *this;
47 }
48 /// Convenience function to set the `tileSizeComputationFunction` to a
49 /// function that computes tile sizes at the point they are needed. Allows
50 /// proper interaction with folding.
51 SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
52
53 /// Computation function that returns the number of threads to use for
54 /// each loop. Returning a num threads of zero implies no tiling for that
55 /// loop. If the size of the returned vector is smaller than the number of
56 /// loops, the inner loops are not tiled. If the size of the returned vector
57 /// is larger, then the vector is truncated to number of loops. Note: This
58 /// option is only supported with loopType set to `LoopType::ForallOp`. If the
59 /// tile size function is not specified while the num threads computation is,
60 /// then the tile size is determined automatically to map at most one tile per
61 /// thread.
62 SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
63
64 SCFTilingOptions &
65 setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
66 numThreadsComputationFunction = std::move(fun);
67 return *this;
68 }
69 /// Convenience function to set the `numThreadsComputationFunction` to a
70 /// function that computes num threads at the point they are needed.
71 SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
72
73 /// The interchange vector to reorder the tiled loops.
74 SmallVector<int64_t> interchangeVector = {};
75 SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
76 interchangeVector = llvm::to_vector(Range&: interchange);
77 return *this;
78 }
79
80 /// Specify which loop construct to use for tile and fuse.
81 enum class LoopType { ForOp, ForallOp };
82 LoopType loopType = LoopType::ForOp;
83 SCFTilingOptions &setLoopType(LoopType type) {
84 loopType = type;
85 return *this;
86 }
87
88 /// Specify mapping of loops to devices. This is only respected when the loop
89 /// constructs support such a mapping (like `scf.forall`). Will be ignored
90 /// when using loop constructs that dont support such a mapping (like
91 /// `scf.for`)
92 SmallVector<Attribute> mappingVector = {};
93 SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
94 mappingVector = llvm::to_vector(Range&: mapping);
95 return *this;
96 }
97
98 //-------------------------------------------------------------------------//
99 // Options related reduction tiling
100 //-------------------------------------------------------------------------//
101
102 /// Specify how reduction dimensions should be tiled.
103 ReductionTilingStrategy reductionStrategy =
104 ReductionTilingStrategy::FullReduction;
105 SCFTilingOptions &
106 setReductionTilingStrategy(ReductionTilingStrategy strategy) {
107 reductionStrategy = strategy;
108 return *this;
109 }
110
111 /// Specify the reduction dimensions to be tiled. Note that this needs to be
112 /// specified. If left unspecified, then none of the reduction dimensions are
113 /// tiled.
114 SetVector<unsigned> reductionDims;
115 SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
116 reductionDims.clear();
117 reductionDims.insert(Start: dims.begin(), End: dims.end());
118 return *this;
119 }
120};
121
122/// Transformation information returned after tiling.
123struct SCFTilingResult {
124 /// Tiled operations that are generated during tiling. The order does not
125 /// matter except the last op. The replacements are expected to be the results
126 /// of the last op.
127 SmallVector<Operation *> tiledOps;
128 /// The initial destination values passed to the tiled operations.
129 SmallVector<Value> initialValues;
130 /// The `scf.for` operations that iterate over the tiles.
131 SmallVector<LoopLikeOpInterface> loops;
132 /// Values to use as replacements for the untiled op. Is the same size as the
133 /// number of results of the untiled op.
134 SmallVector<Value> replacements;
135 /// Slices generated after tiling that can be used for fusing with the tiled
136 /// producer.
137 SmallVector<Operation *> generatedSlices;
138 /// In cases where there as an additional merge step after tiling
139 /// return the merged ops after tiling. This list is empty when reduction
140 /// tiling strategy is
141 /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction.
142 SmallVector<Operation *> mergeOps;
143};
144
145/// Method to tile an op that implements the `TilingInterface` using
146/// `scf.for` for iterating over the tiles.
147FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
148 TilingInterface op,
149 const SCFTilingOptions &options);
150
151/// Options used to control tile + fuse.
152struct SCFTileAndFuseOptions {
153 /// The tiling options used to control the tiling of the consumer.
154 SCFTilingOptions tilingOptions;
155 SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
156 tilingOptions = options;
157 return *this;
158 }
159
160 /// Control function to check if a slice needs to be fused or not,
161 /// The control function receives
162 /// 1) the slice along which fusion is to be done,
163 /// 2) the producer value that is to be fused
164 /// 3) a boolean value set to `true` if the fusion is from
165 /// a destination operand.
166 /// The control function returns an `std::optiona<ControlFnResult>`.
167 /// If the return value is `std::nullopt`, that implies no fusion
168 /// is to be performed along that slice.
169 struct ControlFnResult {
170 /// Set to true if the loop nest has to return a replacement value
171 /// for the fused producer.
172 bool yieldProducerReplacement = false;
173 };
174 using ControlFnTy = std::function<std::optional<ControlFnResult>(
175 tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
176 bool isDestinationOperand)>;
177 /// The default control function implements greedy fusion without yielding
178 /// a replacement for any of the fused results.
179 ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
180 bool) -> std::optional<ControlFnResult> {
181 return ControlFnResult{};
182 };
183 SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
184 fusionControlFn = controlFn;
185 return *this;
186 }
187
188 /// An optional set of rewrite patterns to apply to the results of tiling
189 /// before fusion. This will track deleted and newly inserted
190 /// `tensor.extract_slice` ops and update the worklist.
191 std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
192};
193
194/// Fuse the producer of the source of `candidateSliceOp` by computing the
195/// required slice of the producer in-place. Note that the method
196/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
197/// value but does not delete the slice operation.
198struct SCFFuseProducerOfSliceResult {
199 OpResult origProducer; // Original untiled producer.
200 Value tiledAndFusedProducer; // Tile and fused producer value.
201 SmallVector<Operation *> tiledOps;
202 SmallVector<Operation *> generatedSlices;
203};
204std::optional<SCFFuseProducerOfSliceResult>
205tileAndFuseProducerOfSlice(RewriterBase &rewriter,
206 tensor::ExtractSliceOp candidateSliceOp,
207 MutableArrayRef<LoopLikeOpInterface> loops);
208
209/// Reconstruct the fused producer from within the tiled-and-fused code. Based
210/// on the slice of the producer computed in place it is possible that within
211/// the loop nest same slice of the producer is computed multiple times. It is
212/// in general not possible to recompute the value of the fused producer from
213/// the tiled loop code in such cases. For the cases where no slice of the
214/// producer is computed in a redundant fashion it is possible to reconstruct
215/// the value of the original producer from within the tiled loop. It is upto
216/// the caller to ensure that the producer is not computed redundantly within
217/// the tiled loop nest. For example, consider
218///
219/// ```mlir
220/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
221/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32>
222/// ```
223///
224/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
225/// is,
226///
227/// ```mlir
228/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
229/// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
230/// ...
231/// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
232/// %t1_3 = linalg.matmul ins(%t1_2, ...)
233/// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
234/// scf.yield %t1_4
235/// }
236/// scf.yield %t1_1
237/// }
238/// ```
239///
240/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
241/// if `%1` were tiled only along the rows, the resultant code would be
242///
243/// ```mlir
244/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
245/// ...
246/// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
247/// %t2_2 = linalg.matmul ins(%t2_1, ...)
248/// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
249/// scf.yield %t2_3
250/// }
251/// ```
252///
253/// Here there is no intersection in the different slices of `%t2_1` computed
254/// across iterations of the `scf.for`. In such cases, the value of the original
255/// `%0` can be reconstructed from within the loop body. This is useful in cases
256/// where `%0` had other uses as well. If not reconstructed from within the loop
257/// body, uses of `%0` could not be replaced, making it still live and the
258/// fusion immaterial.
259///
260/// The @param `yieldResultNumber` decides which result would be yield. If not
261/// given, yield all `opResult` of fused producer.
262///
263/// The method returns the list of new slices added during the process (which
264/// can be used to fuse along).
265FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
266 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
267 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
268 MutableArrayRef<LoopLikeOpInterface> loops,
269 ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
270
271/// Transformation information returned after tile and fuse.
272struct SCFTileAndFuseResult {
273 /// List of untiled operations that were fused with the tiled consumer.
274 llvm::SetVector<Operation *> fusedProducers;
275 /// List of tiled and fused operations generated. The first element is always
276 /// the tiled version of the original consumer operation processed by
277 /// `tileConsumerAndFuseProducersUsingSCF`, followed by any operations that
278 /// were fused with it.
279 llvm::SetVector<Operation *> tiledAndFusedOps;
280 /// The `scf.for` operations that iterate over the tiles.
281 SmallVector<LoopLikeOpInterface> loops;
282 /// The replacement values to use for the tiled and fused operations.
283 llvm::DenseMap<Value, Value> replacements;
284};
285
286/// Method to tile and fuse a sequence of operations, by tiling the consumer
287/// and fusing its producers. Note that this assumes that it is valid to
288/// tile+fuse the producer into the innermost tiled loop. Its up to the caller
289/// to ensure that the tile sizes provided make this fusion valid.
290///
291/// For example, for the following sequence
292///
293/// ```mlir
294/// %0 =
295/// %1 = linalg.fill ... outs(%0 : ... )
296/// %2 = linalg.matmul ... outs(%1 : ...) ...
297/// ```
298///
299/// it is legal to fuse the fill with the matmul only if the matmul is tiled
300/// along the parallel dimensions and not the reduction dimension, i.e. the tile
301/// size for the reduction dimension should be 0. The resulting fused
302/// transformation is
303///
304/// ```mlir
305/// %1 = scf.for ... iter_args(%arg0 = %0)
306/// %2 = tensor.extract_slice %arg0
307/// %3 = linalg.fill .. outs(%2 : ... )
308/// %4 = linalg.matmul .. outs(%3 : ...)
309/// }
310/// ```
311FailureOr<SCFTileAndFuseResult>
312tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
313 TilingInterface consumer,
314 const SCFTileAndFuseOptions &options);
315
316/// Fuse the consumer `candidateSlices` by computing the required slice of the
317/// consumer in-place. All the entries of `candidateSlices` are expected to map
318/// to the same consumer. The method returns an error if the consumer cannot be
319/// tiled in a manner that is consistent for all the passed slices. Note that
320/// the method replaces the uses of `candidateSlices` with the tiled and fused
321/// consumer value but does not delete the slice operations.
322struct SCFFuseConsumerOfSliceResult {
323 // Original untiled consumer operands.
324 SmallVector<OpOperand *> origConsumerOperands;
325 // Tiled and fused consumer operands.
326 SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
327 SmallVector<Operation *> tiledOps;
328};
329FailureOr<scf::SCFFuseConsumerOfSliceResult>
330tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
331 ArrayRef<Operation *> candidateSlices,
332 MutableArrayRef<LoopLikeOpInterface> loops);
333
334/// Method to lower an `op` that implements the `TilingInterface` to
335/// loops/scalars.
336FailureOr<SmallVector<scf::ForOp>>
337lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
338
339/// Method to tile a reduction and generate a parallel op within a serial loop.
340/// Each of the partial reductions are calculated in parallel. Then after the
341/// loop all the partial reduction are merged into a final reduction.
342/// For example for the following sequence
343///
344/// ```mlir
345/// %0 = linalg.generic %in ["parallel", "reduction"]
346/// : tensor<7x9xf32> -> tensor<7xf32>
347/// ```
348///
349/// into:
350///
351/// ```mlir
352/// %0 = linalg.fill ... : tensor<7x4xf32>
353/// %1 = scf.for ... iter_args(%arg0 = %0)
354/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
355/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
356/// %4 = linalg.generic %2, %3 ["parallel", "parallel"]
357/// : tensor<7x?xf32> -> tensor<7x?xf32>
358/// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32>
359/// }
360/// %6 = linalg.generic %1 ["parallel", "reduction"]
361/// : tensor<7x4xf32> -> tensor<7xf32>
362/// ```
363FailureOr<scf::SCFTilingResult>
364tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
365 ArrayRef<OpFoldResult> tileSizes);
366
367} // namespace scf
368} // namespace mlir
369
370#endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
371

source code of mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h