//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/LoopLikeInterface.h"
16#include "mlir/Interfaces/TilingInterface.h"
17
18#include <deque>
19
// Forward declarations of the core MLIR classes named by the declarations in
// this header.
namespace mlir {
class TilingInterface;
class RewriterBase;
class Operation;
} // namespace mlir
25
26namespace mlir {
27namespace scf {
28
29using SCFTileSizeComputationFunction =
30 std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
31
32/// Options to use to control tiling.
33struct SCFTilingOptions {
34 /// Computation function that returns the tile sizes for each operation.
35 /// Delayed construction of constant tile sizes should occur to interoperate
36 /// with folding.
37 SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
38
39 SCFTilingOptions &
40 setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
41 tileSizeComputationFunction = std::move(fun);
42 return *this;
43 }
44 /// Convenience function to set the `tileSizeComputationFunction` to a
45 /// function that computes tile sizes at the point they are needed. Allows
46 /// proper interaction with folding.
47 SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
48
49 /// The interchange vector to reorder the tiled loops.
50 SmallVector<int64_t> interchangeVector = {};
51 SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
52 interchangeVector = llvm::to_vector(Range&: interchange);
53 return *this;
54 }
55
56 /// Specify which loop construct to use for tile and fuse.
57 enum class LoopType { ForOp, ForallOp };
58 LoopType loopType = LoopType::ForOp;
59 SCFTilingOptions &setLoopType(LoopType type) {
60 loopType = type;
61 return *this;
62 }
63
64 /// Specify mapping of loops to devices. This is only respected when the loop
65 /// constructs support such a mapping (like `scf.forall`). Will be ignored
66 /// when using loop constructs that dont support such a mapping (like
67 /// `scf.for`)
68 SmallVector<Attribute> mappingVector = {};
69 SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
70 mappingVector = llvm::map_to_vector(
71 mapping, [](auto attr) -> Attribute { return attr; });
72 return *this;
73 }
74};
75
76/// Transformation information returned after tiling.
77struct SCFTilingResult {
78 /// Tiled operations that are generated during tiling. The order does not
79 /// matter except the last op. The replacements are expected to be the results
80 /// of the last op.
81 SmallVector<Operation *> tiledOps;
82 /// The `scf.for` operations that iterate over the tiles.
83 SmallVector<LoopLikeOpInterface> loops;
84 /// Values to use as replacements for the untiled op. Is the same size as the
85 /// number of results of the untiled op.
86 SmallVector<Value> replacements;
87};
88
89/// Method to tile an op that implements the `TilingInterface` using
90/// `scf.for` for iterating over the tiles.
91FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
92 TilingInterface op,
93 const SCFTilingOptions &options);
94
95/// Options used to control tile + fuse.
96struct SCFTileAndFuseOptions {
97 /// The tiling options used to control the tiling of the consumer.
98 SCFTilingOptions tilingOptions;
99 SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
100 tilingOptions = options;
101 return *this;
102 }
103
104 /// Control function to check if a slice needs to be fused or not,
105 /// The control function receives
106 /// 1) the slice along which fusion is to be done,
107 /// 2) the producer value that is to be fused
108 /// 3) a boolean value set to `true` if the fusion is from
109 /// a destination operand.
110 /// It retuns two booleans
111 /// - returns `true` if the fusion should be done through the candidate slice
112 /// - returns `true` if a replacement for the fused producer needs to be
113 /// yielded from within the tiled loop. Note that it is valid to return
114 /// `true` only if the slice fused is disjoint across all iterations of the
115 /// tiled loop. It is up to the caller to ensure that this is true for the
116 /// fused producers.
117 using ControlFnTy = std::function<std::tuple<bool, bool>(
118 tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
119 bool isDestinationOperand)>;
120 ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
121 return std::make_tuple(true, false);
122 };
123 SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
124 fusionControlFn = controlFn;
125 return *this;
126 }
127};
128
129/// Fuse the producer of the source of `candidateSliceOp` by computing the
130/// required slice of the producer in-place. Note that the method
131/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
132/// value but does not delete the slice operation.
133struct SCFFuseProducerOfSliceResult {
134 OpResult origProducer; // Original untiled producer.
135 Value tiledAndFusedProducer; // Tile and fused producer value.
136 SmallVector<Operation *> tiledOps;
137};
138std::optional<SCFFuseProducerOfSliceResult>
139tileAndFuseProducerOfSlice(RewriterBase &rewriter,
140 tensor::ExtractSliceOp candidateSliceOp,
141 MutableArrayRef<LoopLikeOpInterface> loops);
142
143/// Reconstruct the fused producer from within the tiled-and-fused code. Based
144/// on the slice of the producer computed in place it is possible that within
145/// the loop nest same slice of the producer is computed multiple times. It is
146/// in general not possible to recompute the value of the fused producer from
147/// the tiled loop code in such cases. For the cases where no slice of the
148/// producer is computed in a redundant fashion it is possible to reconstruct
149/// the value of the original producer from within the tiled loop. It is upto
150/// the caller to ensure that the producer is not computed redundantly within
151/// the tiled loop nest. For example, consider
152///
153/// ```mlir
154/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
155/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32>
156/// ```
157///
158/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
159/// is,
160///
161/// ```mlir
162/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
163/// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
164/// ...
165/// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
166/// %t1_3 = linalg.matmul ins(%t1_2, ...)
167/// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
168/// scf.yield %t1_4
169/// }
170/// scf.yield %t1_1
171/// }
172/// ```
173///
174/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
175/// if `%1` were tiled only along the rows, the resultant code would be
176///
177/// ```mlir
178/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
179/// ...
180/// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
181/// %t2_2 = linalg.matmul ins(%t2_1, ...)
182/// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
183/// scf.yield %t2_3
184/// }
185/// ```
186///
187/// Here there is no intersection in the different slices of `%t2_1` computed
188/// across iterations of the `scf.for`. In such cases, the value of the original
189/// `%0` can be reconstructed from within the loop body. This is useful in cases
190/// where `%0` had other uses as well. If not reconstructed from within the loop
191/// body, uses of `%0` could not be replaced, making it still live and the
192/// fusion immaterial.
193LogicalResult yieldReplacementForFusedProducer(
194 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
195 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
196 MutableArrayRef<LoopLikeOpInterface> loops);
197
198/// Transformation information returned after tile and fuse.
199struct SCFTileAndFuseResult {
200 /// List of untiled operations that were fused with the tiled consumer.
201 llvm::SetVector<Operation *> fusedProducers;
202 /// List of tiled and fused operations generated. The first one in this list
203 /// is guaranteed to be the tiled operations generated during tiling of the
204 /// generated operation.
205 llvm::SetVector<Operation *> tiledAndFusedOps;
206 /// The `scf.for` operations that iterate over the tiles.
207 SmallVector<LoopLikeOpInterface> loops;
208 /// The replacement values to use for the tiled and fused operations.
209 llvm::DenseMap<Value, Value> replacements;
210};
211
212/// Method to tile and fuse a sequence of operations, by tiling the consumer
213/// and fusing its producers. Note that this assumes that it is valid to
214/// tile+fuse the producer into the innermost tiled loop. Its up to the caller
215/// to ensure that the tile sizes provided make this fusion valid.
216///
217/// For example, for the following sequence
218///
219/// ```mlir
220/// %0 =
221/// %1 = linalg.fill ... outs(%0 : ... )
222/// %2 = linalg.matmul ... outs(%1 : ...) ...
223/// ```
224///
225/// it is legal to fuse the fill with the matmul only if the matmul is tiled
226/// along the parallel dimensions and not the reduction dimension, i.e. the tile
227/// size for the reduction dimension should be 0. The resulting fused
228/// transformation is
229///
230/// ```mlir
231/// %1 = scf.for ... iter_args(%arg0 = %0)
232/// %2 = tensor.extract_slice %arg0
233/// %3 = linalg.fill .. outs(%2 : ... )
234/// %4 = linalg.matmul .. outs(%3 : ...)
235/// }
236/// ```
237FailureOr<SCFTileAndFuseResult>
238tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
239 TilingInterface consumer,
240 const SCFTileAndFuseOptions &options);
241
242/// Method to lower an `op` that implements the `TilingInterface` to
243/// loops/scalars.
244FailureOr<SmallVector<scf::ForOp>>
245lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
246
247/// Transformation information returned after reduction tiling.
248struct SCFReductionTilingResult {
249 /// The partial reduction tiled op generated.
250 Operation *parallelTiledOp;
251 /// The final reduction operation merging all the partial reductions.
252 Operation *mergeOp;
253 /// Initial op
254 Operation *initialOp;
255 /// The loop operations that iterate over the tiles.
256 SmallVector<LoopLikeOpInterface> loops;
257};
258
259/// Method to tile a reduction and generate a parallel op within a serial loop.
260/// Each of the partial reductions are calculated in parallel. Then after the
261/// loop all the partial reduction are merged into a final reduction.
262/// For example for the following sequence
263///
264/// ```mlir
265/// %0 = linalg.generic %in ["parallel", "reduction"]
266/// : tensor<7x9xf32> -> tensor<7xf32>
267/// ```
268///
269/// into:
270///
271/// ```mlir
272/// %0 = linalg.fill ... : tensor<7x4xf32>
273/// %1 = scf.for ... iter_args(%arg0 = %0)
274/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
275/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
276/// %4 = linalg.generic %2, %3 ["parallel", "parallel"]
277/// : tensor<7x?xf32> -> tensor<7x?xf32>
278/// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32>
279/// }
280/// %6 = linalg.generic %1 ["parallel", "reduction"]
281/// : tensor<7x4xf32> -> tensor<7xf32>
282/// ```
283FailureOr<scf::SCFReductionTilingResult>
284tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
285 ArrayRef<OpFoldResult> tileSize);
286
287} // namespace scf
288} // namespace mlir
289
290#endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H

// Source: mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h