//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
| 9 | #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
| 10 | #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
| 11 | |
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

#include <deque>
#include <functional>
#include <optional>
#include <utility>
| 21 | |
// Forward declarations of the core MLIR types that appear in the signatures
// declared in this header.
namespace mlir {
class Operation;
class RewriterBase;
class TilingInterface;
} // namespace mlir
| 27 | |
| 28 | namespace mlir { |
| 29 | namespace scf { |
| 30 | |
| 31 | using SCFTileSizeComputationFunction = |
| 32 | std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>; |
| 33 | |
| 34 | /// Options to use to control tiling. |
| 35 | struct SCFTilingOptions { |
| 36 | /// Computation function that returns the tile sizes to use for each loop. |
| 37 | /// Returning a tile size of zero implies no tiling for that loop. If the |
| 38 | /// size of the returned vector is smaller than the number of loops, the inner |
| 39 | /// loops are not tiled. If the size of the returned vector is larger, then |
| 40 | /// the vector is truncated to number of loops. |
| 41 | SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; |
| 42 | |
| 43 | SCFTilingOptions & |
| 44 | setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) { |
| 45 | tileSizeComputationFunction = std::move(fun); |
| 46 | return *this; |
| 47 | } |
| 48 | /// Convenience function to set the `tileSizeComputationFunction` to a |
| 49 | /// function that computes tile sizes at the point they are needed. Allows |
| 50 | /// proper interaction with folding. |
| 51 | SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes); |
| 52 | |
| 53 | /// Computation function that returns the number of threads to use for |
| 54 | /// each loop. Returning a num threads of zero implies no tiling for that |
| 55 | /// loop. If the size of the returned vector is smaller than the number of |
| 56 | /// loops, the inner loops are not tiled. If the size of the returned vector |
| 57 | /// is larger, then the vector is truncated to number of loops. Note: This |
| 58 | /// option is only supported with loopType set to `LoopType::ForallOp`. If the |
| 59 | /// tile size function is not specified while the num threads computation is, |
| 60 | /// then the tile size is determined automatically to map at most one tile per |
| 61 | /// thread. |
| 62 | SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr; |
| 63 | |
| 64 | SCFTilingOptions & |
| 65 | setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) { |
| 66 | numThreadsComputationFunction = std::move(fun); |
| 67 | return *this; |
| 68 | } |
| 69 | /// Convenience function to set the `numThreadsComputationFunction` to a |
| 70 | /// function that computes num threads at the point they are needed. |
| 71 | SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads); |
| 72 | |
| 73 | /// The interchange vector to reorder the tiled loops. |
| 74 | SmallVector<int64_t> interchangeVector = {}; |
| 75 | SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) { |
| 76 | interchangeVector = llvm::to_vector(Range&: interchange); |
| 77 | return *this; |
| 78 | } |
| 79 | |
| 80 | /// Specify which loop construct to use for tile and fuse. |
| 81 | enum class LoopType { ForOp, ForallOp }; |
| 82 | LoopType loopType = LoopType::ForOp; |
| 83 | SCFTilingOptions &setLoopType(LoopType type) { |
| 84 | loopType = type; |
| 85 | return *this; |
| 86 | } |
| 87 | |
| 88 | /// Specify how reduction dimensions should be tiled. |
| 89 | /// |
| 90 | /// Tiling can be thought of as splitting a dimension into 2 and materializing |
| 91 | /// the outer dimension as a loop: |
| 92 | /// |
| 93 | /// op[original] -> op[original / x, x] -> loop[original] { op[x] } |
| 94 | /// |
| 95 | /// For parallel dimensions, the split can only happen in one way, with both |
| 96 | /// dimensions being parallel. For reduction dimensions however, there is a |
| 97 | /// choice in how we split the reduction dimension. This enum exposes this |
| 98 | /// choice. |
| 99 | enum class ReductionTilingStrategy { |
| 100 | // [reduction] -> [reduction1, reduction2] |
| 101 | // -> loop[reduction1] { [reduction2] } |
| 102 | FullReduction, |
| 103 | // [reduction] -> [reduction1, parallel2] |
| 104 | // -> loop[reduction1] { [parallel2] }; merge[reduction1] |
| 105 | PartialReductionOuterReduction, |
| 106 | // [reduction] -> [parallel1, reduction2] |
| 107 | // -> loop[parallel1] { [reduction2] }; merge[parallel1] |
| 108 | PartialReductionOuterParallel |
| 109 | }; |
| 110 | ReductionTilingStrategy reductionStrategy = |
| 111 | ReductionTilingStrategy::FullReduction; |
| 112 | SCFTilingOptions & |
| 113 | setReductionTilingStrategy(ReductionTilingStrategy strategy) { |
| 114 | reductionStrategy = strategy; |
| 115 | return *this; |
| 116 | } |
| 117 | |
| 118 | /// Specify mapping of loops to devices. This is only respected when the loop |
| 119 | /// constructs support such a mapping (like `scf.forall`). Will be ignored |
| 120 | /// when using loop constructs that dont support such a mapping (like |
| 121 | /// `scf.for`) |
| 122 | SmallVector<Attribute> mappingVector = {}; |
| 123 | SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) { |
| 124 | mappingVector = llvm::to_vector(Range&: mapping); |
| 125 | return *this; |
| 126 | } |
| 127 | }; |
| 128 | |
| 129 | /// Transformation information returned after tiling. |
| 130 | struct SCFTilingResult { |
| 131 | /// Tiled operations that are generated during tiling. The order does not |
| 132 | /// matter except the last op. The replacements are expected to be the results |
| 133 | /// of the last op. |
| 134 | SmallVector<Operation *> tiledOps; |
| 135 | /// The initial destination values passed to the tiled operations. |
| 136 | SmallVector<Value> initialValues; |
| 137 | /// The `scf.for` operations that iterate over the tiles. |
| 138 | SmallVector<LoopLikeOpInterface> loops; |
| 139 | /// Values to use as replacements for the untiled op. Is the same size as the |
| 140 | /// number of results of the untiled op. |
| 141 | SmallVector<Value> replacements; |
| 142 | /// Slices generated after tiling that can be used for fusing with the tiled |
| 143 | /// producer. |
| 144 | SmallVector<Operation *> generatedSlices; |
| 145 | /// In cases where there as an additional merge step after tiling |
| 146 | /// return the merged ops after tiling. This list is empty when reduction |
| 147 | /// tiling strategy is |
| 148 | /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction. |
| 149 | SmallVector<Operation *> mergeOps; |
| 150 | }; |
| 151 | |
| 152 | /// Method to tile an op that implements the `TilingInterface` using |
| 153 | /// `scf.for` for iterating over the tiles. |
| 154 | FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter, |
| 155 | TilingInterface op, |
| 156 | const SCFTilingOptions &options); |
| 157 | |
| 158 | /// Options used to control tile + fuse. |
| 159 | struct SCFTileAndFuseOptions { |
| 160 | /// The tiling options used to control the tiling of the consumer. |
| 161 | SCFTilingOptions tilingOptions; |
| 162 | SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { |
| 163 | tilingOptions = options; |
| 164 | return *this; |
| 165 | } |
| 166 | |
| 167 | /// Control function to check if a slice needs to be fused or not, |
| 168 | /// The control function receives |
| 169 | /// 1) the slice along which fusion is to be done, |
| 170 | /// 2) the producer value that is to be fused |
| 171 | /// 3) a boolean value set to `true` if the fusion is from |
| 172 | /// a destination operand. |
| 173 | /// The control function returns an `std::optiona<ControlFnResult>`. |
| 174 | /// If the return value is `std::nullopt`, that implies no fusion |
| 175 | /// is to be performed along that slice. |
| 176 | struct ControlFnResult { |
| 177 | /// Set to true if the loop nest has to return a replacement value |
| 178 | /// for the fused producer. |
| 179 | bool yieldProducerReplacement = false; |
| 180 | }; |
| 181 | using ControlFnTy = std::function<std::optional<ControlFnResult>( |
| 182 | tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, |
| 183 | bool isDestinationOperand)>; |
| 184 | /// The default control function implements greedy fusion without yielding |
| 185 | /// a replacement for any of the fused results. |
| 186 | ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, |
| 187 | bool) -> std::optional<ControlFnResult> { |
| 188 | return ControlFnResult{}; |
| 189 | }; |
| 190 | SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) { |
| 191 | fusionControlFn = controlFn; |
| 192 | return *this; |
| 193 | } |
| 194 | |
| 195 | /// An optional set of rewrite patterns to apply to the results of tiling |
| 196 | /// before fusion. This will track deleted and newly inserted |
| 197 | /// `tensor.extract_slice` ops and update the worklist. |
| 198 | std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt; |
| 199 | }; |
| 200 | |
| 201 | /// Fuse the producer of the source of `candidateSliceOp` by computing the |
| 202 | /// required slice of the producer in-place. Note that the method |
| 203 | /// replaces the uses of `candidateSliceOp` with the tiled and fused producer |
| 204 | /// value but does not delete the slice operation. |
| 205 | struct SCFFuseProducerOfSliceResult { |
| 206 | OpResult origProducer; // Original untiled producer. |
| 207 | Value tiledAndFusedProducer; // Tile and fused producer value. |
| 208 | SmallVector<Operation *> tiledOps; |
| 209 | SmallVector<Operation *> generatedSlices; |
| 210 | }; |
| 211 | std::optional<SCFFuseProducerOfSliceResult> |
| 212 | tileAndFuseProducerOfSlice(RewriterBase &rewriter, |
| 213 | tensor::ExtractSliceOp candidateSliceOp, |
| 214 | MutableArrayRef<LoopLikeOpInterface> loops); |
| 215 | |
| 216 | /// Reconstruct the fused producer from within the tiled-and-fused code. Based |
| 217 | /// on the slice of the producer computed in place it is possible that within |
| 218 | /// the loop nest same slice of the producer is computed multiple times. It is |
| 219 | /// in general not possible to recompute the value of the fused producer from |
| 220 | /// the tiled loop code in such cases. For the cases where no slice of the |
| 221 | /// producer is computed in a redundant fashion it is possible to reconstruct |
| 222 | /// the value of the original producer from within the tiled loop. It is upto |
| 223 | /// the caller to ensure that the producer is not computed redundantly within |
| 224 | /// the tiled loop nest. For example, consider |
| 225 | /// |
| 226 | /// ```mlir |
| 227 | /// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
| 228 | /// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32> |
| 229 | /// ``` |
| 230 | /// |
| 231 | /// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR |
| 232 | /// is, |
| 233 | /// |
| 234 | /// ```mlir |
| 235 | /// %t1_0 = scf.for .... iter_args(%arg0 = ...) { |
| 236 | /// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) { |
| 237 | /// ... |
| 238 | /// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
| 239 | /// %t1_3 = linalg.matmul ins(%t1_2, ...) |
| 240 | /// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ... |
| 241 | /// scf.yield %t1_4 |
| 242 | /// } |
| 243 | /// scf.yield %t1_1 |
| 244 | /// } |
| 245 | /// ``` |
| 246 | /// |
| 247 | /// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead |
| 248 | /// if `%1` were tiled only along the rows, the resultant code would be |
| 249 | /// |
| 250 | /// ```mlir |
| 251 | /// %t2_0 = scf.for .... iter_args(%arg0 = ...) { |
| 252 | /// ... |
| 253 | /// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
| 254 | /// %t2_2 = linalg.matmul ins(%t2_1, ...) |
| 255 | /// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ... |
| 256 | /// scf.yield %t2_3 |
| 257 | /// } |
| 258 | /// ``` |
| 259 | /// |
| 260 | /// Here there is no intersection in the different slices of `%t2_1` computed |
| 261 | /// across iterations of the `scf.for`. In such cases, the value of the original |
| 262 | /// `%0` can be reconstructed from within the loop body. This is useful in cases |
| 263 | /// where `%0` had other uses as well. If not reconstructed from within the loop |
| 264 | /// body, uses of `%0` could not be replaced, making it still live and the |
| 265 | /// fusion immaterial. |
| 266 | /// |
| 267 | /// The @param `yieldResultNumber` decides which result would be yield. If not |
| 268 | /// given, yield all `opResult` of fused producer. |
| 269 | /// |
| 270 | /// The method returns the list of new slices added during the process (which |
| 271 | /// can be used to fuse along). |
| 272 | FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer( |
| 273 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
| 274 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
| 275 | MutableArrayRef<LoopLikeOpInterface> loops, |
| 276 | ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{}); |
| 277 | |
| 278 | /// Transformation information returned after tile and fuse. |
| 279 | struct SCFTileAndFuseResult { |
| 280 | /// List of untiled operations that were fused with the tiled consumer. |
| 281 | llvm::SetVector<Operation *> fusedProducers; |
| 282 | /// List of tiled and fused operations generated. The first one in this list |
| 283 | /// is guaranteed to be the tiled operations generated during tiling of the |
| 284 | /// generated operation. |
| 285 | llvm::SetVector<Operation *> tiledAndFusedOps; |
| 286 | /// The `scf.for` operations that iterate over the tiles. |
| 287 | SmallVector<LoopLikeOpInterface> loops; |
| 288 | /// The replacement values to use for the tiled and fused operations. |
| 289 | llvm::DenseMap<Value, Value> replacements; |
| 290 | }; |
| 291 | |
| 292 | /// Method to tile and fuse a sequence of operations, by tiling the consumer |
| 293 | /// and fusing its producers. Note that this assumes that it is valid to |
| 294 | /// tile+fuse the producer into the innermost tiled loop. Its up to the caller |
| 295 | /// to ensure that the tile sizes provided make this fusion valid. |
| 296 | /// |
| 297 | /// For example, for the following sequence |
| 298 | /// |
| 299 | /// ```mlir |
| 300 | /// %0 = |
| 301 | /// %1 = linalg.fill ... outs(%0 : ... ) |
| 302 | /// %2 = linalg.matmul ... outs(%1 : ...) ... |
| 303 | /// ``` |
| 304 | /// |
| 305 | /// it is legal to fuse the fill with the matmul only if the matmul is tiled |
| 306 | /// along the parallel dimensions and not the reduction dimension, i.e. the tile |
| 307 | /// size for the reduction dimension should be 0. The resulting fused |
| 308 | /// transformation is |
| 309 | /// |
| 310 | /// ```mlir |
| 311 | /// %1 = scf.for ... iter_args(%arg0 = %0) |
| 312 | /// %2 = tensor.extract_slice %arg0 |
| 313 | /// %3 = linalg.fill .. outs(%2 : ... ) |
| 314 | /// %4 = linalg.matmul .. outs(%3 : ...) |
| 315 | /// } |
| 316 | /// ``` |
| 317 | FailureOr<SCFTileAndFuseResult> |
| 318 | tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, |
| 319 | TilingInterface consumer, |
| 320 | const SCFTileAndFuseOptions &options); |
| 321 | |
| 322 | /// Fuse the consumer of the source of `candidateSliceOp` by computing the |
| 323 | /// required slice of the consumer in-place. Note that the method |
| 324 | /// replaces the uses of `candidateSliceOp` with the tiled and fused consumer |
| 325 | /// value but does not delete the slice operation. |
| 326 | struct SCFFuseConsumerOfSliceResult { |
| 327 | OpOperand *origConsumerOperand; // Original untiled consumer's operand. |
| 328 | OpOperand |
| 329 | *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand. |
| 330 | SmallVector<Operation *> tiledOps; |
| 331 | }; |
| 332 | FailureOr<scf::SCFFuseConsumerOfSliceResult> |
| 333 | tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, |
| 334 | MutableArrayRef<LoopLikeOpInterface> loops); |
| 335 | |
| 336 | /// Method to lower an `op` that implements the `TilingInterface` to |
| 337 | /// loops/scalars. |
| 338 | FailureOr<SmallVector<scf::ForOp>> |
| 339 | lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); |
| 340 | |
| 341 | /// Method to tile a reduction and generate a parallel op within a serial loop. |
| 342 | /// Each of the partial reductions are calculated in parallel. Then after the |
| 343 | /// loop all the partial reduction are merged into a final reduction. |
| 344 | /// For example for the following sequence |
| 345 | /// |
| 346 | /// ```mlir |
| 347 | /// %0 = linalg.generic %in ["parallel", "reduction"] |
| 348 | /// : tensor<7x9xf32> -> tensor<7xf32> |
| 349 | /// ``` |
| 350 | /// |
| 351 | /// into: |
| 352 | /// |
| 353 | /// ```mlir |
| 354 | /// %0 = linalg.fill ... : tensor<7x4xf32> |
| 355 | /// %1 = scf.for ... iter_args(%arg0 = %0) |
| 356 | /// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> |
| 357 | /// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> |
| 358 | /// %4 = linalg.generic %2, %3 ["parallel", "parallel"] |
| 359 | /// : tensor<7x?xf32> -> tensor<7x?xf32> |
| 360 | /// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32> |
| 361 | /// } |
| 362 | /// %6 = linalg.generic %1 ["parallel", "reduction"] |
| 363 | /// : tensor<7x4xf32> -> tensor<7xf32> |
| 364 | /// ``` |
| 365 | FailureOr<scf::SCFTilingResult> |
| 366 | tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, |
| 367 | ArrayRef<OpFoldResult> tileSizes); |
| 368 | |
| 369 | } // namespace scf |
| 370 | } // namespace mlir |
| 371 | |
| 372 | #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
| 373 | |