1//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/LoopLikeInterface.h"
16#include "mlir/Interfaces/TilingInterface.h"
17#include "mlir/Interfaces/ViewLikeInterface.h"
18#include "mlir/Rewrite/FrozenRewritePatternSet.h"
19
20#include <deque>
21
22namespace mlir {
23class Operation;
24class RewriterBase;
25class TilingInterface;
26} // namespace mlir
27
28namespace mlir {
29namespace scf {
30
/// Callback that, given a builder and the operation being tiled, returns the
/// per-loop `OpFoldResult` values to use (tile sizes, or number of threads —
/// see the corresponding fields of `SCFTilingOptions`).
using SCFTileSizeComputationFunction =
    std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
33
34/// Options to use to control tiling.
35struct SCFTilingOptions {
36 /// Computation function that returns the tile sizes to use for each loop.
37 /// Returning a tile size of zero implies no tiling for that loop. If the
38 /// size of the returned vector is smaller than the number of loops, the inner
39 /// loops are not tiled. If the size of the returned vector is larger, then
40 /// the vector is truncated to number of loops.
41 SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
42
43 SCFTilingOptions &
44 setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
45 tileSizeComputationFunction = std::move(fun);
46 return *this;
47 }
48 /// Convenience function to set the `tileSizeComputationFunction` to a
49 /// function that computes tile sizes at the point they are needed. Allows
50 /// proper interaction with folding.
51 SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
52
53 /// Computation function that returns the number of threads to use for
54 /// each loop. Returning a num threads of zero implies no tiling for that
55 /// loop. If the size of the returned vector is smaller than the number of
56 /// loops, the inner loops are not tiled. If the size of the returned vector
57 /// is larger, then the vector is truncated to number of loops. Note: This
58 /// option is only supported with loopType set to `LoopType::ForallOp`. If the
59 /// tile size function is not specified while the num threads computation is,
60 /// then the tile size is determined automatically to map at most one tile per
61 /// thread.
62 SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
63
64 SCFTilingOptions &
65 setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
66 numThreadsComputationFunction = std::move(fun);
67 return *this;
68 }
69 /// Convenience function to set the `numThreadsComputationFunction` to a
70 /// function that computes num threads at the point they are needed.
71 SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
72
73 /// The interchange vector to reorder the tiled loops.
74 SmallVector<int64_t> interchangeVector = {};
75 SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
76 interchangeVector = llvm::to_vector(Range&: interchange);
77 return *this;
78 }
79
80 /// Specify which loop construct to use for tile and fuse.
81 enum class LoopType { ForOp, ForallOp };
82 LoopType loopType = LoopType::ForOp;
83 SCFTilingOptions &setLoopType(LoopType type) {
84 loopType = type;
85 return *this;
86 }
87
88 /// Specify how reduction dimensions should be tiled.
89 ///
90 /// Tiling can be thought of as splitting a dimension into 2 and materializing
91 /// the outer dimension as a loop:
92 ///
93 /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
94 ///
95 /// For parallel dimensions, the split can only happen in one way, with both
96 /// dimensions being parallel. For reduction dimensions however, there is a
97 /// choice in how we split the reduction dimension. This enum exposes this
98 /// choice.
99 enum class ReductionTilingStrategy {
100 // [reduction] -> [reduction1, reduction2]
101 // -> loop[reduction1] { [reduction2] }
102 FullReduction,
103 // [reduction] -> [reduction1, parallel2]
104 // -> loop[reduction1] { [parallel2] }; merge[reduction1]
105 PartialReductionOuterReduction,
106 // [reduction] -> [parallel1, reduction2]
107 // -> loop[parallel1] { [reduction2] }; merge[parallel1]
108 PartialReductionOuterParallel
109 };
110 ReductionTilingStrategy reductionStrategy =
111 ReductionTilingStrategy::FullReduction;
112 SCFTilingOptions &
113 setReductionTilingStrategy(ReductionTilingStrategy strategy) {
114 reductionStrategy = strategy;
115 return *this;
116 }
117
118 /// Specify mapping of loops to devices. This is only respected when the loop
119 /// constructs support such a mapping (like `scf.forall`). Will be ignored
120 /// when using loop constructs that dont support such a mapping (like
121 /// `scf.for`)
122 SmallVector<Attribute> mappingVector = {};
123 SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
124 mappingVector = llvm::to_vector(Range&: mapping);
125 return *this;
126 }
127};
128
/// Transformation information returned after tiling.
struct SCFTilingResult {
  /// Tiled operations that are generated during tiling. The order does not
  /// matter except the last op. The replacements are expected to be the results
  /// of the last op.
  SmallVector<Operation *> tiledOps;
  /// The initial destination values passed to the tiled operations.
  SmallVector<Value> initialValues;
  /// The `scf.for` operations that iterate over the tiles.
  SmallVector<LoopLikeOpInterface> loops;
  /// Values to use as replacements for the untiled op. Is the same size as the
  /// number of results of the untiled op.
  SmallVector<Value> replacements;
  /// Slices generated after tiling that can be used for fusing with the tiled
  /// producer.
  SmallVector<Operation *> generatedSlices;
  /// In cases where there is an additional merge step after tiling,
  /// return the merged ops after tiling. This list is empty when the reduction
  /// tiling strategy is
  /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction`.
  SmallVector<Operation *> mergeOps;
};
151
/// Method to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles. The loop construct actually
/// generated is controlled by `options` (see `SCFTilingOptions::loopType`).
/// Returns failure if tiling could not be performed.
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
                                        TilingInterface op,
                                        const SCFTilingOptions &options);
157
158/// Options used to control tile + fuse.
159struct SCFTileAndFuseOptions {
160 /// The tiling options used to control the tiling of the consumer.
161 SCFTilingOptions tilingOptions;
162 SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
163 tilingOptions = options;
164 return *this;
165 }
166
167 /// Control function to check if a slice needs to be fused or not,
168 /// The control function receives
169 /// 1) the slice along which fusion is to be done,
170 /// 2) the producer value that is to be fused
171 /// 3) a boolean value set to `true` if the fusion is from
172 /// a destination operand.
173 /// The control function returns an `std::optiona<ControlFnResult>`.
174 /// If the return value is `std::nullopt`, that implies no fusion
175 /// is to be performed along that slice.
176 struct ControlFnResult {
177 /// Set to true if the loop nest has to return a replacement value
178 /// for the fused producer.
179 bool yieldProducerReplacement = false;
180 };
181 using ControlFnTy = std::function<std::optional<ControlFnResult>(
182 tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
183 bool isDestinationOperand)>;
184 /// The default control function implements greedy fusion without yielding
185 /// a replacement for any of the fused results.
186 ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
187 bool) -> std::optional<ControlFnResult> {
188 return ControlFnResult{};
189 };
190 SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
191 fusionControlFn = controlFn;
192 return *this;
193 }
194
195 /// An optional set of rewrite patterns to apply to the results of tiling
196 /// before fusion. This will track deleted and newly inserted
197 /// `tensor.extract_slice` ops and update the worklist.
198 std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
199};
200
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
/// value but does not delete the slice operation.
struct SCFFuseProducerOfSliceResult {
  OpResult origProducer;       // Original untiled producer.
  Value tiledAndFusedProducer; // Tile and fused producer value.
  // Tiled operations created while fusing the producer.
  SmallVector<Operation *> tiledOps;
  // New slice ops created during fusion; candidates for further fusion.
  SmallVector<Operation *> generatedSlices;
};
// Returns std::nullopt when fusion along `candidateSliceOp` is not performed.
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                           tensor::ExtractSliceOp candidateSliceOp,
                           MutableArrayRef<LoopLikeOpInterface> loops);
215
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
/// the loop nest same slice of the producer is computed multiple times. It is
/// in general not possible to recompute the value of the fused producer from
/// the tiled loop code in such cases. For the cases where no slice of the
/// producer is computed in a redundant fashion it is possible to reconstruct
/// the value of the original producer from within the tiled loop. It is up to
/// the caller to ensure that the producer is not computed redundantly within
/// the tiled loop nest. For example, consider
///
/// ```mlir
/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32>
/// ```
///
/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
/// is,
///
/// ```mlir
/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
///   %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
///     ...
///     %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
///     %t1_3 = linalg.matmul ins(%t1_2, ...)
///     %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
///     scf.yield %t1_4
///   }
///   scf.yield %t1_1
/// }
/// ```
///
/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
/// if `%1` were tiled only along the rows, the resultant code would be
///
/// ```mlir
/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
///   ...
///   %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
///   %t2_2 = linalg.matmul ins(%t2_1, ...)
///   %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
///   scf.yield %t2_3
/// }
/// ```
///
/// Here there is no intersection in the different slices of `%t2_1` computed
/// across iterations of the `scf.for`. In such cases, the value of the original
/// `%0` can be reconstructed from within the loop body. This is useful in cases
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
///
/// The @param `yieldResultNumber` decides which results are yielded. If not
/// given, yield all `opResult` of fused producer.
///
/// The method returns the list of new slices added during the process (which
/// can be used to fuse along).
FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<LoopLikeOpInterface> loops,
    ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
277
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
  /// List of untiled operations that were fused with the tiled consumer.
  llvm::SetVector<Operation *> fusedProducers;
  /// List of tiled and fused operations generated. The first one in this list
  /// is guaranteed to be the tiled operations generated during tiling of the
  /// generated operation.
  llvm::SetVector<Operation *> tiledAndFusedOps;
  /// The `scf.for` operations that iterate over the tiles.
  SmallVector<LoopLikeOpInterface> loops;
  /// The replacement values to use for the tiled and fused operations.
  /// NOTE(review): presumed to map original (untiled) values to their
  /// replacements — confirm against call sites.
  llvm::DenseMap<Value, Value> replacements;
};
291
/// Method to tile and fuse a sequence of operations, by tiling the consumer
/// and fusing its producers. Note that this assumes that it is valid to
/// tile+fuse the producer into the innermost tiled loop. It's up to the caller
/// to ensure that the tile sizes provided make this fusion valid.
///
/// For example, for the following sequence
///
/// ```mlir
/// %0 =
/// %1 = linalg.fill ... outs(%0 : ... )
/// %2 = linalg.matmul ... outs(%1 : ...) ...
/// ```
///
/// it is legal to fuse the fill with the matmul only if the matmul is tiled
/// along the parallel dimensions and not the reduction dimension, i.e. the tile
/// size for the reduction dimension should be 0. The resulting fused
/// transformation is
///
/// ```mlir
/// %1 = scf.for ... iter_args(%arg0 = %0)
///   %2 = tensor.extract_slice %arg0
///   %3 = linalg.fill .. outs(%2 : ... )
///   %4 = linalg.matmul .. outs(%3 : ...)
/// }
/// ```
FailureOr<SCFTileAndFuseResult>
tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                     TilingInterface consumer,
                                     const SCFTileAndFuseOptions &options);
321
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
/// required slice of the consumer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
/// value but does not delete the slice operation.
struct SCFFuseConsumerOfSliceResult {
  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
  OpOperand
      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
  // Tiled operations created while fusing the consumer.
  SmallVector<Operation *> tiledOps;
};
// NOTE(review): `candidateSliceOp` is taken as a generic `Operation *` rather
// than a specific slice op type — confirm at the implementation which slice
// ops (e.g. insert_slice variants) are accepted.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
                           MutableArrayRef<LoopLikeOpInterface> loops);
335
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars. Returns the generated `scf.for` loop nest, or failure.
FailureOr<SmallVector<scf::ForOp>>
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
340
/// Method to tile a reduction and generate a parallel op within a serial loop.
/// Each of the partial reductions are calculated in parallel. Then after the
/// loop all the partial reductions are merged into a final reduction.
/// For example for the following sequence
///
/// ```mlir
/// %0 = linalg.generic %in ["parallel", "reduction"]
///   : tensor<7x9xf32> -> tensor<7xf32>
/// ```
///
/// into:
///
/// ```mlir
/// %0 = linalg.fill ... : tensor<7x4xf32>
/// %1 = scf.for ... iter_args(%arg0 = %0)
///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
///   %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
///   %4 = linalg.generic %2, %3 ["parallel", "parallel"]
///     : tensor<7x?xf32> -> tensor<7x?xf32>
///   %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32>
/// }
/// %6 = linalg.generic %1 ["parallel", "reduction"]
///   : tensor<7x4xf32> -> tensor<7xf32>
/// ```
FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
                      ArrayRef<OpFoldResult> tileSizes);
368
369} // namespace scf
370} // namespace mlir
371
372#endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
373

// NOTE: The following trailing lines are webpage chrome from the code-browser
// page this header was captured from; they are not part of the source file and
// are commented out here so the header remains compilable.
// Provided by KDAB
// Privacy Policy
// Update your C++ knowledge – Modern C++11/14/17 Training
// Find out more
// source code of mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h