1 | //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
10 | #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
11 | |
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"

#include <deque>
#include <utility>
19 | |
// Forward declarations of MLIR classes that the declarations in this header
// refer to.
namespace mlir {
class Operation;
class RewriterBase;
class TilingInterface;
} // namespace mlir
25 | |
26 | namespace mlir { |
27 | namespace scf { |
28 | |
29 | using SCFTileSizeComputationFunction = |
30 | std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>; |
31 | |
32 | /// Options to use to control tiling. |
33 | struct SCFTilingOptions { |
34 | /// Computation function that returns the tile sizes for each operation. |
35 | /// Delayed construction of constant tile sizes should occur to interoperate |
36 | /// with folding. |
37 | SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; |
38 | |
39 | SCFTilingOptions & |
40 | setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) { |
41 | tileSizeComputationFunction = std::move(fun); |
42 | return *this; |
43 | } |
44 | /// Convenience function to set the `tileSizeComputationFunction` to a |
45 | /// function that computes tile sizes at the point they are needed. Allows |
46 | /// proper interaction with folding. |
47 | SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts); |
48 | |
49 | /// The interchange vector to reorder the tiled loops. |
50 | SmallVector<int64_t> interchangeVector = {}; |
51 | SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) { |
52 | interchangeVector = llvm::to_vector(Range&: interchange); |
53 | return *this; |
54 | } |
55 | |
56 | /// Specify which loop construct to use for tile and fuse. |
57 | enum class LoopType { ForOp, ForallOp }; |
58 | LoopType loopType = LoopType::ForOp; |
59 | SCFTilingOptions &setLoopType(LoopType type) { |
60 | loopType = type; |
61 | return *this; |
62 | } |
63 | |
64 | /// Specify mapping of loops to devices. This is only respected when the loop |
65 | /// constructs support such a mapping (like `scf.forall`). Will be ignored |
66 | /// when using loop constructs that dont support such a mapping (like |
67 | /// `scf.for`) |
68 | SmallVector<Attribute> mappingVector = {}; |
69 | SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) { |
70 | mappingVector = llvm::map_to_vector( |
71 | mapping, [](auto attr) -> Attribute { return attr; }); |
72 | return *this; |
73 | } |
74 | }; |
75 | |
76 | /// Transformation information returned after tiling. |
77 | struct SCFTilingResult { |
78 | /// Tiled operations that are generated during tiling. The order does not |
79 | /// matter except the last op. The replacements are expected to be the results |
80 | /// of the last op. |
81 | SmallVector<Operation *> tiledOps; |
82 | /// The `scf.for` operations that iterate over the tiles. |
83 | SmallVector<LoopLikeOpInterface> loops; |
84 | /// Values to use as replacements for the untiled op. Is the same size as the |
85 | /// number of results of the untiled op. |
86 | SmallVector<Value> replacements; |
87 | }; |
88 | |
89 | /// Method to tile an op that implements the `TilingInterface` using |
90 | /// `scf.for` for iterating over the tiles. |
91 | FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter, |
92 | TilingInterface op, |
93 | const SCFTilingOptions &options); |
94 | |
95 | /// Options used to control tile + fuse. |
96 | struct SCFTileAndFuseOptions { |
97 | /// The tiling options used to control the tiling of the consumer. |
98 | SCFTilingOptions tilingOptions; |
99 | SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { |
100 | tilingOptions = options; |
101 | return *this; |
102 | } |
103 | |
104 | /// Control function to check if a slice needs to be fused or not, |
105 | /// The control function receives |
106 | /// 1) the slice along which fusion is to be done, |
107 | /// 2) the producer value that is to be fused |
108 | /// 3) a boolean value set to `true` if the fusion is from |
109 | /// a destination operand. |
110 | /// It retuns two booleans |
111 | /// - returns `true` if the fusion should be done through the candidate slice |
112 | /// - returns `true` if a replacement for the fused producer needs to be |
113 | /// yielded from within the tiled loop. Note that it is valid to return |
114 | /// `true` only if the slice fused is disjoint across all iterations of the |
115 | /// tiled loop. It is up to the caller to ensure that this is true for the |
116 | /// fused producers. |
117 | using ControlFnTy = std::function<std::tuple<bool, bool>( |
118 | tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, |
119 | bool isDestinationOperand)>; |
120 | ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) { |
121 | return std::make_tuple(true, false); |
122 | }; |
123 | SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) { |
124 | fusionControlFn = controlFn; |
125 | return *this; |
126 | } |
127 | }; |
128 | |
129 | /// Fuse the producer of the source of `candidateSliceOp` by computing the |
130 | /// required slice of the producer in-place. Note that the method |
131 | /// replaces the uses of `candidateSliceOp` with the tiled and fused producer |
132 | /// value but does not delete the slice operation. |
133 | struct SCFFuseProducerOfSliceResult { |
134 | OpResult origProducer; // Original untiled producer. |
135 | Value tiledAndFusedProducer; // Tile and fused producer value. |
136 | SmallVector<Operation *> tiledOps; |
137 | }; |
138 | std::optional<SCFFuseProducerOfSliceResult> |
139 | tileAndFuseProducerOfSlice(RewriterBase &rewriter, |
140 | tensor::ExtractSliceOp candidateSliceOp, |
141 | MutableArrayRef<LoopLikeOpInterface> loops); |
142 | |
143 | /// Reconstruct the fused producer from within the tiled-and-fused code. Based |
144 | /// on the slice of the producer computed in place it is possible that within |
145 | /// the loop nest same slice of the producer is computed multiple times. It is |
146 | /// in general not possible to recompute the value of the fused producer from |
147 | /// the tiled loop code in such cases. For the cases where no slice of the |
148 | /// producer is computed in a redundant fashion it is possible to reconstruct |
149 | /// the value of the original producer from within the tiled loop. It is upto |
150 | /// the caller to ensure that the producer is not computed redundantly within |
151 | /// the tiled loop nest. For example, consider |
152 | /// |
153 | /// ```mlir |
154 | /// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
155 | /// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32> |
156 | /// ``` |
157 | /// |
158 | /// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR |
159 | /// is, |
160 | /// |
161 | /// ```mlir |
162 | /// %t1_0 = scf.for .... iter_args(%arg0 = ...) { |
163 | /// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) { |
164 | /// ... |
165 | /// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
166 | /// %t1_3 = linalg.matmul ins(%t1_2, ...) |
167 | /// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ... |
168 | /// scf.yield %t1_4 |
169 | /// } |
170 | /// scf.yield %t1_1 |
171 | /// } |
172 | /// ``` |
173 | /// |
174 | /// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead |
175 | /// if `%1` were tiled only along the rows, the resultant code would be |
176 | /// |
177 | /// ```mlir |
178 | /// %t2_0 = scf.for .... iter_args(%arg0 = ...) { |
179 | /// ... |
180 | /// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
181 | /// %t2_2 = linalg.matmul ins(%t2_1, ...) |
182 | /// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ... |
183 | /// scf.yield %t2_3 |
184 | /// } |
185 | /// ``` |
186 | /// |
187 | /// Here there is no intersection in the different slices of `%t2_1` computed |
188 | /// across iterations of the `scf.for`. In such cases, the value of the original |
189 | /// `%0` can be reconstructed from within the loop body. This is useful in cases |
190 | /// where `%0` had other uses as well. If not reconstructed from within the loop |
191 | /// body, uses of `%0` could not be replaced, making it still live and the |
192 | /// fusion immaterial. |
193 | LogicalResult yieldReplacementForFusedProducer( |
194 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
195 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
196 | MutableArrayRef<LoopLikeOpInterface> loops); |
197 | |
198 | /// Transformation information returned after tile and fuse. |
199 | struct SCFTileAndFuseResult { |
200 | /// List of untiled operations that were fused with the tiled consumer. |
201 | llvm::SetVector<Operation *> fusedProducers; |
202 | /// List of tiled and fused operations generated. The first one in this list |
203 | /// is guaranteed to be the tiled operations generated during tiling of the |
204 | /// generated operation. |
205 | llvm::SetVector<Operation *> tiledAndFusedOps; |
206 | /// The `scf.for` operations that iterate over the tiles. |
207 | SmallVector<LoopLikeOpInterface> loops; |
208 | /// The replacement values to use for the tiled and fused operations. |
209 | llvm::DenseMap<Value, Value> replacements; |
210 | }; |
211 | |
212 | /// Method to tile and fuse a sequence of operations, by tiling the consumer |
213 | /// and fusing its producers. Note that this assumes that it is valid to |
214 | /// tile+fuse the producer into the innermost tiled loop. Its up to the caller |
215 | /// to ensure that the tile sizes provided make this fusion valid. |
216 | /// |
217 | /// For example, for the following sequence |
218 | /// |
219 | /// ```mlir |
220 | /// %0 = |
221 | /// %1 = linalg.fill ... outs(%0 : ... ) |
222 | /// %2 = linalg.matmul ... outs(%1 : ...) ... |
223 | /// ``` |
224 | /// |
225 | /// it is legal to fuse the fill with the matmul only if the matmul is tiled |
226 | /// along the parallel dimensions and not the reduction dimension, i.e. the tile |
227 | /// size for the reduction dimension should be 0. The resulting fused |
228 | /// transformation is |
229 | /// |
230 | /// ```mlir |
231 | /// %1 = scf.for ... iter_args(%arg0 = %0) |
232 | /// %2 = tensor.extract_slice %arg0 |
233 | /// %3 = linalg.fill .. outs(%2 : ... ) |
234 | /// %4 = linalg.matmul .. outs(%3 : ...) |
235 | /// } |
236 | /// ``` |
237 | FailureOr<SCFTileAndFuseResult> |
238 | tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, |
239 | TilingInterface consumer, |
240 | const SCFTileAndFuseOptions &options); |
241 | |
242 | /// Method to lower an `op` that implements the `TilingInterface` to |
243 | /// loops/scalars. |
244 | FailureOr<SmallVector<scf::ForOp>> |
245 | lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); |
246 | |
247 | /// Transformation information returned after reduction tiling. |
248 | struct SCFReductionTilingResult { |
249 | /// The partial reduction tiled op generated. |
250 | Operation *parallelTiledOp; |
251 | /// The final reduction operation merging all the partial reductions. |
252 | Operation *mergeOp; |
253 | /// Initial op |
254 | Operation *initialOp; |
255 | /// The loop operations that iterate over the tiles. |
256 | SmallVector<LoopLikeOpInterface> loops; |
257 | }; |
258 | |
259 | /// Method to tile a reduction and generate a parallel op within a serial loop. |
260 | /// Each of the partial reductions are calculated in parallel. Then after the |
261 | /// loop all the partial reduction are merged into a final reduction. |
262 | /// For example for the following sequence |
263 | /// |
264 | /// ```mlir |
265 | /// %0 = linalg.generic %in ["parallel", "reduction"] |
266 | /// : tensor<7x9xf32> -> tensor<7xf32> |
267 | /// ``` |
268 | /// |
269 | /// into: |
270 | /// |
271 | /// ```mlir |
272 | /// %0 = linalg.fill ... : tensor<7x4xf32> |
273 | /// %1 = scf.for ... iter_args(%arg0 = %0) |
274 | /// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> |
275 | /// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> |
276 | /// %4 = linalg.generic %2, %3 ["parallel", "parallel"] |
277 | /// : tensor<7x?xf32> -> tensor<7x?xf32> |
278 | /// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32> |
279 | /// } |
280 | /// %6 = linalg.generic %1 ["parallel", "reduction"] |
281 | /// : tensor<7x4xf32> -> tensor<7xf32> |
282 | /// ``` |
283 | FailureOr<scf::SCFReductionTilingResult> |
284 | tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, |
285 | ArrayRef<OpFoldResult> tileSize); |
286 | |
287 | } // namespace scf |
288 | } // namespace mlir |
289 | |
290 | #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
291 | |