| 1 | //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
| 10 | #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
| 11 | |
| 12 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 13 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| 14 | #include "mlir/IR/PatternMatch.h" |
| 15 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 16 | #include "mlir/Interfaces/TilingInterface.h" |
| 17 | #include "mlir/Interfaces/ViewLikeInterface.h" |
| 18 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 19 | |
| 20 | #include <deque> |
| 21 | |
// Forward declarations of the core MLIR classes that appear in the signatures
// declared in this header.
namespace mlir {
class Operation;
class RewriterBase;
class TilingInterface;
} // namespace mlir
| 27 | |
| 28 | namespace mlir { |
| 29 | namespace scf { |
| 30 | |
| 31 | using SCFTileSizeComputationFunction = |
| 32 | std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>; |
| 33 | |
| 34 | /// Options to use to control tiling. |
| 35 | struct SCFTilingOptions { |
| 36 | /// Computation function that returns the tile sizes to use for each loop. |
| 37 | /// Returning a tile size of zero implies no tiling for that loop. If the |
| 38 | /// size of the returned vector is smaller than the number of loops, the inner |
| 39 | /// loops are not tiled. If the size of the returned vector is larger, then |
| 40 | /// the vector is truncated to number of loops. |
| 41 | SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; |
| 42 | |
| 43 | SCFTilingOptions & |
| 44 | setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) { |
| 45 | tileSizeComputationFunction = std::move(fun); |
| 46 | return *this; |
| 47 | } |
| 48 | /// Convenience function to set the `tileSizeComputationFunction` to a |
| 49 | /// function that computes tile sizes at the point they are needed. Allows |
| 50 | /// proper interaction with folding. |
| 51 | SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes); |
| 52 | |
| 53 | /// Computation function that returns the number of threads to use for |
| 54 | /// each loop. Returning a num threads of zero implies no tiling for that |
| 55 | /// loop. If the size of the returned vector is smaller than the number of |
| 56 | /// loops, the inner loops are not tiled. If the size of the returned vector |
| 57 | /// is larger, then the vector is truncated to number of loops. Note: This |
| 58 | /// option is only supported with loopType set to `LoopType::ForallOp`. If the |
| 59 | /// tile size function is not specified while the num threads computation is, |
| 60 | /// then the tile size is determined automatically to map at most one tile per |
| 61 | /// thread. |
| 62 | SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr; |
| 63 | |
| 64 | SCFTilingOptions & |
| 65 | setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) { |
| 66 | numThreadsComputationFunction = std::move(fun); |
| 67 | return *this; |
| 68 | } |
| 69 | /// Convenience function to set the `numThreadsComputationFunction` to a |
| 70 | /// function that computes num threads at the point they are needed. |
| 71 | SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads); |
| 72 | |
| 73 | /// The interchange vector to reorder the tiled loops. |
| 74 | SmallVector<int64_t> interchangeVector = {}; |
| 75 | SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) { |
| 76 | interchangeVector = llvm::to_vector(Range&: interchange); |
| 77 | return *this; |
| 78 | } |
| 79 | |
| 80 | /// Specify which loop construct to use for tile and fuse. |
| 81 | enum class LoopType { ForOp, ForallOp }; |
| 82 | LoopType loopType = LoopType::ForOp; |
| 83 | SCFTilingOptions &setLoopType(LoopType type) { |
| 84 | loopType = type; |
| 85 | return *this; |
| 86 | } |
| 87 | |
| 88 | /// Specify mapping of loops to devices. This is only respected when the loop |
| 89 | /// constructs support such a mapping (like `scf.forall`). Will be ignored |
| 90 | /// when using loop constructs that dont support such a mapping (like |
| 91 | /// `scf.for`) |
| 92 | SmallVector<Attribute> mappingVector = {}; |
| 93 | SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) { |
| 94 | mappingVector = llvm::to_vector(Range&: mapping); |
| 95 | return *this; |
| 96 | } |
| 97 | |
| 98 | //-------------------------------------------------------------------------// |
| 99 | // Options related reduction tiling |
| 100 | //-------------------------------------------------------------------------// |
| 101 | |
| 102 | /// Specify how reduction dimensions should be tiled. |
| 103 | ReductionTilingStrategy reductionStrategy = |
| 104 | ReductionTilingStrategy::FullReduction; |
| 105 | SCFTilingOptions & |
| 106 | setReductionTilingStrategy(ReductionTilingStrategy strategy) { |
| 107 | reductionStrategy = strategy; |
| 108 | return *this; |
| 109 | } |
| 110 | |
| 111 | /// Specify the reduction dimensions to be tiled. Note that this needs to be |
| 112 | /// specified. If left unspecified, then none of the reduction dimensions are |
| 113 | /// tiled. |
| 114 | SetVector<unsigned> reductionDims; |
| 115 | SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) { |
| 116 | reductionDims.clear(); |
| 117 | reductionDims.insert(Start: dims.begin(), End: dims.end()); |
| 118 | return *this; |
| 119 | } |
| 120 | }; |
| 121 | |
| 122 | /// Transformation information returned after tiling. |
| 123 | struct SCFTilingResult { |
| 124 | /// Tiled operations that are generated during tiling. The order does not |
| 125 | /// matter except the last op. The replacements are expected to be the results |
| 126 | /// of the last op. |
| 127 | SmallVector<Operation *> tiledOps; |
| 128 | /// The initial destination values passed to the tiled operations. |
| 129 | SmallVector<Value> initialValues; |
| 130 | /// The `scf.for` operations that iterate over the tiles. |
| 131 | SmallVector<LoopLikeOpInterface> loops; |
| 132 | /// Values to use as replacements for the untiled op. Is the same size as the |
| 133 | /// number of results of the untiled op. |
| 134 | SmallVector<Value> replacements; |
| 135 | /// Slices generated after tiling that can be used for fusing with the tiled |
| 136 | /// producer. |
| 137 | SmallVector<Operation *> generatedSlices; |
| 138 | /// In cases where there as an additional merge step after tiling |
| 139 | /// return the merged ops after tiling. This list is empty when reduction |
| 140 | /// tiling strategy is |
| 141 | /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction. |
| 142 | SmallVector<Operation *> mergeOps; |
| 143 | }; |
| 144 | |
| 145 | /// Method to tile an op that implements the `TilingInterface` using |
| 146 | /// `scf.for` for iterating over the tiles. |
| 147 | FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter, |
| 148 | TilingInterface op, |
| 149 | const SCFTilingOptions &options); |
| 150 | |
| 151 | /// Options used to control tile + fuse. |
| 152 | struct SCFTileAndFuseOptions { |
| 153 | /// The tiling options used to control the tiling of the consumer. |
| 154 | SCFTilingOptions tilingOptions; |
| 155 | SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { |
| 156 | tilingOptions = options; |
| 157 | return *this; |
| 158 | } |
| 159 | |
| 160 | /// Control function to check if a slice needs to be fused or not, |
| 161 | /// The control function receives |
| 162 | /// 1) the slice along which fusion is to be done, |
| 163 | /// 2) the producer value that is to be fused |
| 164 | /// 3) a boolean value set to `true` if the fusion is from |
| 165 | /// a destination operand. |
| 166 | /// The control function returns an `std::optiona<ControlFnResult>`. |
| 167 | /// If the return value is `std::nullopt`, that implies no fusion |
| 168 | /// is to be performed along that slice. |
| 169 | struct ControlFnResult { |
| 170 | /// Set to true if the loop nest has to return a replacement value |
| 171 | /// for the fused producer. |
| 172 | bool yieldProducerReplacement = false; |
| 173 | }; |
| 174 | using ControlFnTy = std::function<std::optional<ControlFnResult>( |
| 175 | tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, |
| 176 | bool isDestinationOperand)>; |
| 177 | /// The default control function implements greedy fusion without yielding |
| 178 | /// a replacement for any of the fused results. |
| 179 | ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, |
| 180 | bool) -> std::optional<ControlFnResult> { |
| 181 | return ControlFnResult{}; |
| 182 | }; |
| 183 | SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) { |
| 184 | fusionControlFn = controlFn; |
| 185 | return *this; |
| 186 | } |
| 187 | |
| 188 | /// An optional set of rewrite patterns to apply to the results of tiling |
| 189 | /// before fusion. This will track deleted and newly inserted |
| 190 | /// `tensor.extract_slice` ops and update the worklist. |
| 191 | std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt; |
| 192 | }; |
| 193 | |
| 194 | /// Fuse the producer of the source of `candidateSliceOp` by computing the |
| 195 | /// required slice of the producer in-place. Note that the method |
| 196 | /// replaces the uses of `candidateSliceOp` with the tiled and fused producer |
| 197 | /// value but does not delete the slice operation. |
| 198 | struct SCFFuseProducerOfSliceResult { |
| 199 | OpResult origProducer; // Original untiled producer. |
| 200 | Value tiledAndFusedProducer; // Tile and fused producer value. |
| 201 | SmallVector<Operation *> tiledOps; |
| 202 | SmallVector<Operation *> generatedSlices; |
| 203 | }; |
| 204 | std::optional<SCFFuseProducerOfSliceResult> |
| 205 | tileAndFuseProducerOfSlice(RewriterBase &rewriter, |
| 206 | tensor::ExtractSliceOp candidateSliceOp, |
| 207 | MutableArrayRef<LoopLikeOpInterface> loops); |
| 208 | |
| 209 | /// Reconstruct the fused producer from within the tiled-and-fused code. Based |
| 210 | /// on the slice of the producer computed in place it is possible that within |
| 211 | /// the loop nest same slice of the producer is computed multiple times. It is |
| 212 | /// in general not possible to recompute the value of the fused producer from |
| 213 | /// the tiled loop code in such cases. For the cases where no slice of the |
| 214 | /// producer is computed in a redundant fashion it is possible to reconstruct |
| 215 | /// the value of the original producer from within the tiled loop. It is upto |
| 216 | /// the caller to ensure that the producer is not computed redundantly within |
| 217 | /// the tiled loop nest. For example, consider |
| 218 | /// |
| 219 | /// ```mlir |
| 220 | /// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
| 221 | /// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32> |
| 222 | /// ``` |
| 223 | /// |
| 224 | /// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR |
| 225 | /// is, |
| 226 | /// |
| 227 | /// ```mlir |
| 228 | /// %t1_0 = scf.for .... iter_args(%arg0 = ...) { |
| 229 | /// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) { |
| 230 | /// ... |
| 231 | /// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
| 232 | /// %t1_3 = linalg.matmul ins(%t1_2, ...) |
| 233 | /// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ... |
| 234 | /// scf.yield %t1_4 |
| 235 | /// } |
| 236 | /// scf.yield %t1_1 |
| 237 | /// } |
| 238 | /// ``` |
| 239 | /// |
| 240 | /// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead |
| 241 | /// if `%1` were tiled only along the rows, the resultant code would be |
| 242 | /// |
| 243 | /// ```mlir |
| 244 | /// %t2_0 = scf.for .... iter_args(%arg0 = ...) { |
| 245 | /// ... |
| 246 | /// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> |
| 247 | /// %t2_2 = linalg.matmul ins(%t2_1, ...) |
| 248 | /// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ... |
| 249 | /// scf.yield %t2_3 |
| 250 | /// } |
| 251 | /// ``` |
| 252 | /// |
| 253 | /// Here there is no intersection in the different slices of `%t2_1` computed |
| 254 | /// across iterations of the `scf.for`. In such cases, the value of the original |
| 255 | /// `%0` can be reconstructed from within the loop body. This is useful in cases |
| 256 | /// where `%0` had other uses as well. If not reconstructed from within the loop |
| 257 | /// body, uses of `%0` could not be replaced, making it still live and the |
| 258 | /// fusion immaterial. |
| 259 | /// |
| 260 | /// The @param `yieldResultNumber` decides which result would be yield. If not |
| 261 | /// given, yield all `opResult` of fused producer. |
| 262 | /// |
| 263 | /// The method returns the list of new slices added during the process (which |
| 264 | /// can be used to fuse along). |
| 265 | FailureOr<SmallVector<Operation *>> ( |
| 266 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
| 267 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
| 268 | MutableArrayRef<LoopLikeOpInterface> loops, |
| 269 | ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{}); |
| 270 | |
| 271 | /// Transformation information returned after tile and fuse. |
| 272 | struct SCFTileAndFuseResult { |
| 273 | /// List of untiled operations that were fused with the tiled consumer. |
| 274 | llvm::SetVector<Operation *> fusedProducers; |
| 275 | /// List of tiled and fused operations generated. The first element is always |
| 276 | /// the tiled version of the original consumer operation processed by |
| 277 | /// `tileConsumerAndFuseProducersUsingSCF`, followed by any operations that |
| 278 | /// were fused with it. |
| 279 | llvm::SetVector<Operation *> tiledAndFusedOps; |
| 280 | /// The `scf.for` operations that iterate over the tiles. |
| 281 | SmallVector<LoopLikeOpInterface> loops; |
| 282 | /// The replacement values to use for the tiled and fused operations. |
| 283 | llvm::DenseMap<Value, Value> replacements; |
| 284 | }; |
| 285 | |
| 286 | /// Method to tile and fuse a sequence of operations, by tiling the consumer |
| 287 | /// and fusing its producers. Note that this assumes that it is valid to |
| 288 | /// tile+fuse the producer into the innermost tiled loop. Its up to the caller |
| 289 | /// to ensure that the tile sizes provided make this fusion valid. |
| 290 | /// |
| 291 | /// For example, for the following sequence |
| 292 | /// |
| 293 | /// ```mlir |
| 294 | /// %0 = |
| 295 | /// %1 = linalg.fill ... outs(%0 : ... ) |
| 296 | /// %2 = linalg.matmul ... outs(%1 : ...) ... |
| 297 | /// ``` |
| 298 | /// |
| 299 | /// it is legal to fuse the fill with the matmul only if the matmul is tiled |
| 300 | /// along the parallel dimensions and not the reduction dimension, i.e. the tile |
| 301 | /// size for the reduction dimension should be 0. The resulting fused |
| 302 | /// transformation is |
| 303 | /// |
| 304 | /// ```mlir |
| 305 | /// %1 = scf.for ... iter_args(%arg0 = %0) |
| 306 | /// %2 = tensor.extract_slice %arg0 |
| 307 | /// %3 = linalg.fill .. outs(%2 : ... ) |
| 308 | /// %4 = linalg.matmul .. outs(%3 : ...) |
| 309 | /// } |
| 310 | /// ``` |
| 311 | FailureOr<SCFTileAndFuseResult> |
| 312 | tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, |
| 313 | TilingInterface consumer, |
| 314 | const SCFTileAndFuseOptions &options); |
| 315 | |
| 316 | /// Fuse the consumer `candidateSlices` by computing the required slice of the |
| 317 | /// consumer in-place. All the entries of `candidateSlices` are expected to map |
| 318 | /// to the same consumer. The method returns an error if the consumer cannot be |
| 319 | /// tiled in a manner that is consistent for all the passed slices. Note that |
| 320 | /// the method replaces the uses of `candidateSlices` with the tiled and fused |
| 321 | /// consumer value but does not delete the slice operations. |
| 322 | struct SCFFuseConsumerOfSliceResult { |
| 323 | // Original untiled consumer operands. |
| 324 | SmallVector<OpOperand *> origConsumerOperands; |
| 325 | // Tiled and fused consumer operands. |
| 326 | SmallVector<OpOperand *> tiledAndFusedConsumerOperands; |
| 327 | SmallVector<Operation *> tiledOps; |
| 328 | }; |
| 329 | FailureOr<scf::SCFFuseConsumerOfSliceResult> |
| 330 | tileAndFuseConsumerOfSlices(RewriterBase &rewriter, |
| 331 | ArrayRef<Operation *> candidateSlices, |
| 332 | MutableArrayRef<LoopLikeOpInterface> loops); |
| 333 | |
| 334 | /// Method to lower an `op` that implements the `TilingInterface` to |
| 335 | /// loops/scalars. |
| 336 | FailureOr<SmallVector<scf::ForOp>> |
| 337 | lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); |
| 338 | |
| 339 | /// Method to tile a reduction and generate a parallel op within a serial loop. |
| 340 | /// Each of the partial reductions are calculated in parallel. Then after the |
| 341 | /// loop all the partial reduction are merged into a final reduction. |
| 342 | /// For example for the following sequence |
| 343 | /// |
| 344 | /// ```mlir |
| 345 | /// %0 = linalg.generic %in ["parallel", "reduction"] |
| 346 | /// : tensor<7x9xf32> -> tensor<7xf32> |
| 347 | /// ``` |
| 348 | /// |
| 349 | /// into: |
| 350 | /// |
| 351 | /// ```mlir |
| 352 | /// %0 = linalg.fill ... : tensor<7x4xf32> |
| 353 | /// %1 = scf.for ... iter_args(%arg0 = %0) |
| 354 | /// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> |
| 355 | /// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> |
| 356 | /// %4 = linalg.generic %2, %3 ["parallel", "parallel"] |
| 357 | /// : tensor<7x?xf32> -> tensor<7x?xf32> |
| 358 | /// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32> |
| 359 | /// } |
| 360 | /// %6 = linalg.generic %1 ["parallel", "reduction"] |
| 361 | /// : tensor<7x4xf32> -> tensor<7xf32> |
| 362 | /// ``` |
| 363 | FailureOr<scf::SCFTilingResult> |
| 364 | tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, |
| 365 | ArrayRef<OpFoldResult> tileSizes); |
| 366 | |
| 367 | } // namespace scf |
| 368 | } // namespace mlir |
| 369 | |
| 370 | #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H |
| 371 | |