//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H

#include <utility>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
class BufferizationState;
} // namespace bufferization

namespace linalg {

class LinalgOp;

//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//

/// Return vector::CombiningKind for the given op.
std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);

//===----------------------------------------------------------------------===//
// Bufferization-related transforms.
//===----------------------------------------------------------------------===//

struct BufferizeToAllocationOptions {
  enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
  AllocOp allocOp = AllocOp::MemrefAlloc;

  enum class MemcpyOp {
    MaterializeInDestination = 0,
    MemrefCopy = 1,
    LinalgCopy = 2
  };
  MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination;

  /// If set to "true", only the destination tensor operands are bufferized to
  /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
  /// targeted op itself.
  bool bufferizeDestinationOnly = false;

  /// If set to "true", a `memref.dealloc` operation will be emitted for each
  /// allocated buffer. Otherwise, the memory is leaked, which is useful if
  /// the buffer deallocation pipeline should be run after bufferization is
  /// done.
  bool emitDealloc = false;
};

/// Materialize a buffer allocation for the given tensor.pad op and lower the
/// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
/// E.g.:
///
/// %0 = tensor.pad low[%l] high[%h] %t ...
///
/// is lowered to:
///
/// %alloc = memref.alloc
/// linalg.fill ... outs(%alloc)
/// %subview = memref.subview %alloc [%l] [...] [1]
/// bufferization.materialize_in_destination %t in %subview
/// %0 = bufferization.to_tensor %alloc restrict writable
///
/// In addition to rewriting the IR as shown above, this function returns the
/// newly allocated buffer. The `insertionPoint` parameter can be used to
/// specify a custom insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
    const BufferizeToAllocationOptions &options,
    tensor::PadOp padOp, Attribute memorySpace = {},
    Operation *insertionPoint = nullptr);
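// Example usage (sketch): a `RewriterBase &rewriter` and a `tensor::PadOp
// padOp` are assumed to be in scope, e.g. inside a rewrite pattern; the option
// values below are illustrative only.
//
//   BufferizeToAllocationOptions opts;
//   opts.allocOp = BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
//   opts.emitDealloc = false; // rely on a later buffer-deallocation pipeline
//   Value buffer = bufferizeToAllocation(rewriter, opts, padOp,
//                                        /*memorySpace=*/Attribute(),
//                                        /*insertionPoint=*/nullptr);
//   // `buffer` is the newly allocated memref backing the pad result.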

/// Materialize a buffer allocation for the given vector.mask op and bufferize
/// the op, including its region. E.g.:
///
/// %0 = vector.mask {
///   vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
/// } : vector<16xi1> -> tensor<?xf32>
///
/// is lowered to:
///
/// %alloc = memref.alloc
/// bufferization.materialize_in_destination %t in %alloc
/// vector.mask {
///   vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
/// } : vector<16xi1>
/// %0 = bufferization.to_tensor %alloc restrict writable
///
/// In addition to rewriting the IR as shown above, this function returns the
/// newly allocated buffer. The `insertionPoint` parameter can be used to
/// specify a custom insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
    const BufferizeToAllocationOptions &options,
    vector::MaskOp maskOp, Attribute memorySpace = {},
    Operation *insertionPoint = nullptr);

/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
/// and lower the op to memref.alloc + memref.tensor_store.
///
/// In addition to rewriting the IR, this function returns the newly allocated
/// buffer. The `insertionPoint` parameter can be used to specify a custom
/// insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
    const BufferizeToAllocationOptions &options,
    bufferization::AllocTensorOp allocTensorOp,
    Attribute memorySpace = {},
    Operation *insertionPoint = nullptr);

/// Bufferize the given op with tensor semantics and materialize the result in
/// a newly allocated buffer.
///
/// Only bufferizable ops that bufferize to a memory write or have an
/// aliasing OpOperand (and do not themselves bufferize to an allocation) are
/// supported. They are bufferized using their BufferizableOpInterface
/// implementation.
///
/// Selected ops that bufferize to an allocation (or need special handling) are
/// also supported:
/// - tensor.pad
/// - vector.mask
///
/// This function returns the newly allocated buffer. The `insertionPoint`
/// parameter can be used to specify a custom insertion point for the buffer
/// allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
    const BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace = {},
    Operation *insertionPoint = nullptr);

/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
/// LinalgOp. This transform looks for LinalgOps that have an unused output
/// operand and an input operand that is rooted in a tensor::EmptyOp. The
/// tensor::EmptyOp uses are replaced with the output operand and the two
/// operands of the LinalgOp are swapped.
///
/// Example:
/// %0 = tensor.empty()
/// %1 = linalg.matmul ins(...) outs(%0)
/// %2 = linalg.generic ins(%1) outs(%dest) {
///   ^bb0(%in: f32, %out: f32):
///   // out not used
/// }
///
/// The IR is transformed as follows:
/// %0 = tensor.empty()
/// %1 = linalg.matmul ins(...) outs(%dest)
/// %2 = linalg.generic ins(%0) outs(%1) {
///   ^bb0(%in: f32, %out: f32):
///   // Use %out instead of %in
/// }
///
/// The "ins" operand has no uses inside the body of the LinalgOp and can be
/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp
/// can also fold away.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(
    RewriterBase &rewriter, Operation *op,
    bufferization::OneShotAnalysisState &state);

//===----------------------------------------------------------------------===//
// Structs that configure the behavior of various transformations.
//===----------------------------------------------------------------------===//

using TileSizeComputationFunction =
    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;

struct LinalgTilingOptions {
  /// Computation function that returns the tile sizes for each operation.
  /// Delayed construction of constant tile sizes should occur to interoperate
  /// with folding.
  TileSizeComputationFunction tileSizeComputationFunction = nullptr;

  LinalgTilingOptions &
  setTileSizeComputationFunction(TileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
  /// values must not fold away when tiling. Otherwise, use a more robust
  /// `tileSizeComputationFunction`.
  LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);

  /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
  /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
  LinalgTilingOptions &scalarizeDynamicDims();

  /// The interchange vector to reorder the tiled loops.
  SmallVector<unsigned, 4> interchangeVector = {};

  LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
    interchangeVector.assign(interchange.begin(), interchange.end());
    return *this;
  }

  /// The type of tile loops to generate.
  LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;

  LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
    loopType = lt;
    return *this;
  }

  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  std::optional<LinalgLoopDistributionOptions> distribution;

  LinalgTilingOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    distribution = std::move(distributionOptions);
    return *this;
  }

  /// Specification markers of how to distribute the `linalg.tiled_loop`.
  SmallVector<StringRef, 2> distributionTypes = {};

  LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
    distributionTypes.assign(types.begin(), types.end());
    return *this;
  }

  /// Peel the specified loops.
  SmallVector<int64_t> peeledLoops;

  LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
    peeledLoops.clear();
    peeledLoops.append(loops.begin(), loops.end());
    return *this;
  }
};
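// Example usage (sketch): a typical tiling configuration, assuming a
// `RewriterBase &rewriter` and a `LinalgOp op` are available inside a rewrite
// pattern. The tile sizes and interchange below are illustrative values only;
// a tile size of 0 means "do not tile this loop".
//
//   LinalgTilingOptions tilingOptions;
//   tilingOptions.setTileSizes({8, 16, 0})
//       .setInterchange({1, 0, 2})
//       .setLoopType(LinalgTilingLoopType::Loops);
//   FailureOr<TiledLinalgOp> tiled = tileLinalgOp(rewriter, op, tilingOptions);
//   if (failed(tiled))
//     return failure();
//   // `tiled->loops` holds the generated loops, `tiled->op` the tiled clone.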

struct LinalgTilingAndFusionOptions {
  /// Tile sizes used to tile the root operation.
  SmallVector<int64_t> tileSizes;
  LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
    tileSizes.assign(ts.begin(), ts.end());
    return *this;
  }
  /// Tile interchange used to permute the tile loops.
  SmallVector<int64_t> tileInterchange;
  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  std::optional<LinalgLoopDistributionOptions> tileDistribution;
  LinalgTilingAndFusionOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    tileDistribution = std::move(distributionOptions);
    return *this;
  }
};

struct LinalgPaddingOptions {
  /// A padding value for every operand.
  SmallVector<Attribute> paddingValues;
  LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
    paddingValues.assign(pv.begin(), pv.end());
    return *this;
  }
  /// A list of iterator dimensions to pad.
  SmallVector<int64_t> paddingDimensions;
  LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
    paddingDimensions.assign(pd.begin(), pd.end());
    return *this;
  }
  /// A list of multiples to which each padding dimension should be padded.
  std::optional<SmallVector<int64_t>> padToMultipleOf;
  LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) {
    padToMultipleOf.emplace(m.begin(), m.end());
    return *this;
  }
  /// A flag for every operand to mark the PadOp as nofold which enables
  /// packing for statically shaped operands.
  SmallVector<bool> nofoldFlags;
  LinalgPaddingOptions &setNofoldFlags(ArrayRef<bool> pp) {
    nofoldFlags.assign(pp.begin(), pp.end());
    return *this;
  }
  /// A number of loops to hoist the PadOp out for every operand.
  SmallVector<int64_t> hoistPaddings;
  LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
    hoistPaddings.assign(hp.begin(), hp.end());
    return *this;
  }
  /// A permutation vector for every operand used to transpose the packed
  /// PadOp results.
  SmallVector<SmallVector<int64_t>> transposePaddings;
  LinalgPaddingOptions &
  setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
    transposePaddings.assign(tp.begin(), tp.end());
    return *this;
  }
  enum class CopyBackOp : int8_t {
    None = 0,
    BufferizationMaterializeInDestination = 1,
    LinalgCopy = 2
  };
  /// The op to be used for copying the padded result to the original
  /// destination tensor.
  CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination;
  LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
    copyBackOp = op;
    return *this;
  }
};

/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use.
using AllocBufferCallbackFn = std::function<std::optional<Value>(
    OpBuilder &b, memref::SubViewOp subView,
    ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;

/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
using DeallocBufferCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value buffer)>;

/// Callback function type used to insert copy from original subview to
/// subview of the promoted region for the read operands/subview of promoted
/// region to original subview for the results. The copy has to happen from
/// `src` to `dst`.
using CopyCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;

struct LinalgPromotionOptions {
  /// Indices of subViews to promote. If `std::nullopt`, try to promote all
  /// operands.
  std::optional<DenseSet<unsigned>> operandsToPromote;
  LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
    operandsToPromote = DenseSet<unsigned>();
    operandsToPromote->insert_range(operands);
    return *this;
  }
  /// If the ith element of `useFullTiles` is true, the full view should be used
  /// for the promoted buffer of the ith operand in `operandsToPromote`.
  /// Otherwise the partial view will be used. The decision is defaulted to
  /// `useFullTileBuffersDefault` when `useFullTileBuffers` is std::nullopt and
  /// for operands missing from `useFullTileBuffers`.
  std::optional<llvm::SmallBitVector> useFullTileBuffers;
  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
    unsigned size = useFullTiles.size();
    llvm::SmallBitVector tmp(size, false);
    for (unsigned i = 0; i < size; ++i)
      tmp[i] = useFullTiles[i];
    useFullTileBuffers = tmp;
    return *this;
  }
  /// If true all operands unspecified by `useFullTileBuffers` will use the
  /// full view, otherwise the partial view.
  bool useFullTileBuffersDefault = false;
  LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
    useFullTileBuffersDefault = use;
    return *this;
  }
  /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment.
  std::optional<unsigned> alignment;
  LinalgPromotionOptions &setAlignment(unsigned align) {
    alignment = align;
    return *this;
  }
  /// Memory space of promoted buffer. If `std::nullopt` do not specify memory
  /// space.
  std::optional<Attribute> memorySpace;
  LinalgPromotionOptions &setMemorySpace(Attribute memorySpc) {
    memorySpace = memorySpc;
    return *this;
  }
  /// Use alloca with the default allocation scheme.
  bool useAlloca = false;
  LinalgPromotionOptions &setUseAlloca(bool use) {
    useAlloca = use;
    return *this;
  }
  /// Callback function to do the allocation of the promoted buffer. If
  /// std::nullopt, then the default allocation scheme of allocating a
  /// memref<?xi8> buffer followed by a view operation is used.
  std::optional<AllocBufferCallbackFn> allocationFn;
  std::optional<DeallocBufferCallbackFn> deallocationFn;
  LinalgPromotionOptions &
  setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
      DeallocBufferCallbackFn const &deallocFn) {
    allocationFn = allocFn;
    deallocationFn = deallocFn;
    return *this;
  }
  /// Callback function to do the copy of data to and from the promoted
  /// subview. If std::nullopt then a memref.copy is used.
  std::optional<CopyCallbackFn> copyInFn;
  std::optional<CopyCallbackFn> copyOutFn;
  LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
      CopyCallbackFn const &copyOut) {
    copyInFn = copyIn;
    copyOutFn = copyOut;
    return *this;
  }
};
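// Example usage (sketch): promote the first two operands into aligned buffers,
// assuming an `OpBuilder &b` and a `LinalgOp op` are in scope. All option
// values below are illustrative.
//
//   LinalgPromotionOptions promoOptions;
//   promoOptions.setOperandsToPromote({0, 1})
//       .setUseFullTileBuffersByDefault(true)
//       .setAlignment(16);
//   FailureOr<LinalgOp> promoted = promoteSubViews(b, op, promoOptions);
//   if (failed(promoted))
//     return failure();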

/// Split Reduction options.
struct SplitReductionOptions {
  // Ratio used to split the reduction dimension. If the ratio is <= 1,
  // nothing will be done.
  int64_t ratio = 0;
  // Index where the extra dimension is added to the intermediate tensor
  // shape.
  unsigned index = 0;
  // If the inner dimension after splitting is parallel or reduction.
  bool innerParallel = false;
};

/// Function signature to control reduction splitting. This returns
/// `SplitReductionOptions`.
// TODO: don't use unsigned unless doing bit manipulation.
using ControlSplitReductionFn =
    std::function<SplitReductionOptions(LinalgOp op)>;
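// Example usage (sketch): a control function that splits every reduction by a
// factor of 8 along a new leading dimension; `rewriter` and `linalgOp` are
// assumed to be available when calling splitReduction (declared further below).
//
//   ControlSplitReductionFn control = [](LinalgOp op) {
//     return SplitReductionOptions{/*ratio=*/8, /*index=*/0,
//                                  /*innerParallel=*/false};
//   };
//   FailureOr<SplitReductionResult> split =
//       splitReduction(rewriter, linalgOp, control);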

//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//

/// Return true if two `linalg.generic` operations with producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
bool areElementwiseOpsFusable(OpOperand *fusedOperand);

/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
    LinalgPromotionOptions options);

/// Return success if the operation can be vectorized.
LogicalResult vectorizeOpPrecondition(Operation *op,
    ArrayRef<int64_t> inputVectorSizes = {},
    ArrayRef<bool> inputScalableVecDims = {},
    bool vectorizeNDExtract = false,
    bool flatten1DDepthwiseConv = false);

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

using LinalgLoops = SmallVector<Operation *, 4>;

/// Transformation to drop unit-extent dimensions from `linalg.generic`
/// operations.
struct ControlDropUnitDims {
  enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };

  RankReductionStrategy rankReductionStrategy =
      RankReductionStrategy::ReassociativeReshape;

  using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
  ControlFnTy controlFn = [](Operation *op) {
    if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
      return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
    }
    if (auto padOp = dyn_cast_or_null<tensor::PadOp>(op)) {
      return llvm::to_vector(
          llvm::seq<unsigned>(0, padOp.getSourceType().getRank()));
    }
    return SmallVector<unsigned>{};
  };
};
struct DropUnitDimsResult {
  linalg::GenericOp resultOp;
  SmallVector<Value> replacements;
};
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
    GenericOp genericOp,
    const ControlDropUnitDims &options);
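// Example usage (sketch): drop unit dimensions from a linalg.generic using the
// default control function, assuming `rewriter` and `genericOp` are in scope.
//
//   ControlDropUnitDims dropOptions; // default: consider all loops
//   FailureOr<DropUnitDimsResult> dropped =
//       dropUnitDims(rewriter, genericOp, dropOptions);
//   if (succeeded(dropped))
//     rewriter.replaceOp(genericOp, dropped->replacements);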

/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
struct ElementwiseOpFusionResult {
  Operation *fusedOp;
  llvm::DenseMap<Value, Value> replacements;
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
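// Example usage (sketch): fuse a producer into its consumer through
// `fusedOperand` (an OpOperand of the consumer), assuming `rewriter` is
// available inside a rewrite pattern.
//
//   if (!areElementwiseOpsFusable(fusedOperand))
//     return failure();
//   FailureOr<ElementwiseOpFusionResult> fused =
//       fuseElementwiseOps(rewriter, fusedOperand);
//   if (failed(fused))
//     return failure();
//   for (auto [origVal, replacement] : fused->replacements)
//     rewriter.replaceAllUsesWith(origVal, replacement);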

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
    GenericOp consumer,
    OpOperand *fusedOperand);

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);

/// Peel 'loops' and apply affine_min/max bounds simplification on the fly
/// where relevant.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);

/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
/// to a static bounding box. The original `opToPad` is cloned and operates on
/// the padded tensors.
///
/// * "options.padToMultipleOf" indicates that each padding dimension should be
/// padded to the specified multiple.
/// * Use "options.paddingValues" and "options.nofoldFlags" to set padding
/// value and nofold attribute of the created tensor::PadOps, respectively.
/// * The unpadded results (extracted slice of the cloned operation) are
/// returned via `replacements`.
/// * The tensor::PadOps are returned via `padOps`.
/// * "options.copyBackOp" specifies the op type for copying back the unpadded
/// result to the original destination tensor.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
    const LinalgPaddingOptions &options,
    LinalgOp &paddedOp,
    SmallVector<Value> &replacements,
    SmallVector<tensor::PadOp> &padOps);
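// Example usage (sketch): pad the first two iterator dimensions of `opToPad`
// to multiples of 16 and copy the result back via
// bufferization.materialize_in_destination. `rewriter` and `opToPad` are
// assumed to be in scope; the option values are illustrative, and
// `paddingValues`/`nofoldFlags` would typically also be set per operand.
//
//   LinalgPaddingOptions padOptions;
//   padOptions.setPaddingDimensions({0, 1})
//       .setPadToMultipleOf({16, 16})
//       .setCopyBackOp(LinalgPaddingOptions::CopyBackOp::
//                          BufferizationMaterializeInDestination);
//   LinalgOp paddedOp;
//   SmallVector<Value> replacements;
//   SmallVector<tensor::PadOp> padOps;
//   if (failed(rewriteAsPaddedOp(rewriter, opToPad, padOptions, paddedOp,
//                                replacements, padOps)))
//     return failure();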

namespace detail {

/// Helper struct to hold the results of building a packing loop nest.
struct PackingResult {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  TransposeOp maybeTransposeOp;
  tensor::PadOp hoistedPadOp;
};

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
FailureOr<PackingResult>
buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp,
    ArrayRef<int64_t> transposeVector);

} // namespace detail

/// Mechanically hoist padding operations on tensors by `numLoops` into a new,
/// generally larger tensor. This achieves packing of multiple padding ops into
/// a larger tensor. On success, `opToHoist` is replaced by the cloned version
/// in the packing loop so the caller can continue reasoning about the padding
/// operation. If `transposeVector` is non-empty, hoist padding introduces a
/// TransposeOp to transpose the padded tensor before inserting it into the
/// packed tensor. A `transposeVector` can change the storage order of the
/// padded tensor but does not change the order of the pack or compute loops.
///
/// TODO: In the future, we should consider rewriting as a linalg.pack after
/// hoisting since this abstraction is now available.
///
/// Example in pseudo-mlir:
/// =======================
///
/// If hoistPaddingOnTensors is called with `numLoops` = 2 on the following IR.
/// ```
/// scf.for (%i, %j, %k)
///   %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
///   %0 = tensor.pad %st0 low[0, 0] high[...] {
///   ^bb0( ... ):
///     linalg.yield %pad
///   } : tensor<?x?xf32> to tensor<4x8xf32>
///   compute(%0)
/// ```
///
/// IR resembling the following is produced:
///
/// ```
/// scf.for (%i) {
///   %packed_init = tensor.empty range(%j) : tensor<?x4x8xf32>
///   %packed = scf.for (%k) iter_args(%p : %packed_init) {
///     %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
///     %0 = tensor.pad %st0 low[0, 0] high[...] {
///     ^bb0( ... ):
///       linalg.yield %pad
///     } : tensor<?x?xf32> to tensor<4x8xf32>
///     %1 = tensor.insert_slice %0 ...
///         : tensor<4x8xf32> to tensor<?x4x8xf32>
///     scf.yield %1: tensor<?x4x8xf32>
///   } -> tensor<?x4x8xf32>
///   scf.for (%j, %k) {
///     %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
///         tensor<?x4x8xf32> to tensor<4x8xf32>
///     compute(%st0)
///   }
/// }
/// ```
FailureOr<Value>
hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist,
    int64_t numLoops, ArrayRef<int64_t> transposeVector,
    tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps);
/// Calls into `hoistPaddingOnTensors` with a local IRRewriter.
FailureOr<Value>
hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector,
    tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps);

/// Apply padding and hoisting to `linalgOp` according to the configuration
/// specified in `options`.
FailureOr<LinalgOp> padAndHoistLinalgOp(RewriterBase &rewriter,
    LinalgOp linalgOp,
    const LinalgPaddingOptions &options);

/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.
/// If the second part is statically known to be empty, do not create it
/// and return nullptr instead. Error state is signalled by returning
/// a pair of nullptrs.
///
/// For example, the following op:
///
/// linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
/// outs(%2 : tensor<128x64xf32>)
///
/// split along the first dimension at position 42 will result in:
///
/// %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
/// %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
/// %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
/// outs(%4 : tensor<42x64xf32>)
/// %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
///
/// %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
/// %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
/// %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
/// outs(%8 : tensor<86x64xf32>)
/// tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
///
/// Note that there is no simplification other than constant propagation applied
/// to slice extraction and insertion.
std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter,
    TilingInterface op,
    unsigned dimension,
    OpFoldResult splitPoint);
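// Example usage (sketch): split the first iteration-space dimension of a
// tileable op at a static position, assuming `rewriter` and a `TilingInterface
// op` are available. The split point 42 is illustrative.
//
//   auto [firstPart, secondPart] =
//       splitOp(rewriter, op, /*dimension=*/0,
//               /*splitPoint=*/rewriter.getIndexAttr(42));
//   if (!firstPart)
//     return failure(); // error state: both parts are null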

/// Perform standalone tiling of a single LinalgOp by `tileSizes`, and permute
/// the loop nest according to `interchangeVector`.
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
/// must be equal to the length of `tileSizes`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// Return a struct containing the tiled loops in the specified order
/// and the cloned op if successful, failure otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
struct TiledLinalgOp {
  LinalgOp op;
  SmallVector<Operation *, 8> loops;
  SmallVector<Value, 4> tensorResults;
};
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
    const LinalgTilingOptions &options);

/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled
/// by `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
///
/// Return failure if the permutation is not valid.
FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
    GenericOp genericOp,
    ArrayRef<unsigned> interchangeVector);

/// Create a GenericOp from the given named operation `linalgOp` and replace
/// the given `linalgOp`.
/// Return failure if `linalgOp` is a GenericOp or misses a region builder.
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
    LinalgOp linalgOp);

/// Create a namedOp from the given GenericOp and replace the GenericOp.
/// Currently we can specialize only trivial linalg copy operations.
FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
    GenericOp genericOp);

/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that
/// can be computed for the size of the result of `subView`. Returns the
/// allocated buffer as `fullLocalView` and the view that matches the size of
/// the result of subview operation as `partialLocalView`.
struct PromotionInfo {
  Value fullLocalView;
  Value partialLocalView;
};
FailureOr<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
    const AllocBufferCallbackFn &allocationFn,
    DataLayout &layout);

/// Promote the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the
/// boundary).
/// 2. Take a full view on the buffer.
/// 3. Take a partial slice of the full view in step 2. and copy into it.
///
/// Return the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
    const LinalgPromotionOptions &options);

/// Allocate the subview in the GPU workgroup memory.
std::optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
    memref::SubViewOp subview,
    ArrayRef<Value> sizeBounds,
    DataLayout &);

/// In case of GPU group memory there is no need to deallocate.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/);

/// Create Memref copy operations and add gpu barrier guards before and after
/// the copy operation to ensure data integrity.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst);

/// Allocate the subview in the GPU private memory.
std::optional<Value> allocateGPUPrivateMemory(OpBuilder &builder,
    memref::SubViewOp subview,
    ArrayRef<Value> sizeBounds,
    DataLayout &);

/// Normal copy between src and dst.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);

/// In case of GPU private memory there is no need to deallocate since the
/// memory is freed when going outside of the scope.
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);

/// Return true if there's dedicated logic in the Linalg Vectorizer to
/// vectorize this Op, false otherwise.
///
/// Note that this helper merely implements a very high level check and that the
/// vectorizer also requires various additional pre-conditions to be met for it
/// to work (these are checked by the vectorizer itself).
bool hasVectorizationImpl(Operation *);

/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
/// must match the rank of the iteration space of the operation and the sizes
/// must be smaller than or equal to their counterpart iteration space sizes,
/// if static. `inputVectorSizes` also allows the vectorization of operations with
/// dynamic shapes.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
    ArrayRef<int64_t> inputVectorSizes = {},
    ArrayRef<bool> inputScalableVecDims = {},
    bool vectorizeNDExtract = false,
    bool flatten1DDepthwiseConv = false);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
    LinalgOp linalgOp);

/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(RewriterBase &rewriter,
    LinalgOp linalgOp);

/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(RewriterBase &rewriter,
    LinalgOp linalgOp);

/// Creates a number of ranges equal to the number of non-zero entries in `tileSizes`,
/// one for each loop of the LinalgOp that is tiled. The `tileSizes` argument
/// has one entry per surrounding loop. It uses zero as the convention that a
/// particular loop is not tiled. This convention simplifies implementations
/// by avoiding affine map manipulations. The returned ranges correspond to
/// the loop ranges, in the proper order, that are tiled and for which new
/// loops will be created. Also the function returns a map from loop indices
/// of the LinalgOp to the corresponding non-empty range indices of newly
/// created loops.
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
    ArrayRef<OpFoldResult> allShapeSizes,
    ArrayRef<OpFoldResult> allTileSizes);

namespace detail {
template <typename T>
struct MultiSizeSpecificationBase {
  /// Tile sizes.
  T lowTileSize, highTileSize;
  /// Number of tiles associated with each size.
  T lowTripCount, highTripCount;
};

template <typename T>
struct ContinuousTileSizeSpecificationBase {
  /// Tile sizes.
  SmallVector<T> tileSizes;
  /// Number of tiles associated with each size.
  SmallVector<T> tripCounts;
};

} // namespace detail

/// A description of a multi-size tiling comprising tile sizes and numbers of
/// tiles, expressed as Values which may or may not be constant. Multi-size
/// currently means two-size.
struct MultiSizeSpecification
    : public detail::MultiSizeSpecificationBase<Value> {};
struct StaticMultiSizeSpecification
    : public detail::MultiSizeSpecificationBase<int64_t> {};

struct ContinuousTileSizeSpecification
    : public detail::ContinuousTileSizeSpecificationBase<Value> {};
struct StaticContinuousTileSizeSpecification
    : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};

/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
/// that there exist numbers of tiles with these sizes that fully cover the
/// given iteration space `dimension` of the structured `op`.
///
/// The computation is as follows:
///
///   b = originalTripCount floordiv sizeDivisor
///   t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
///   d = (b + t - 1) floordiv t
///   s = (b floordiv d) * sizeDivisor
///   v = b % d
///   u = d - v
///
/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
///
///   s * u + (s + sizeDivisor) * v == original size,
///   where s mod sizeDivisor = 0.
///
/// Expects all values to be positive. In some cases with the target tile size
/// sufficiently close to the dimension shape and non-unit divisor, it is
/// impossible to compute such sizes. If `emitAssertion` is set, also emit the
/// assertion that size computation succeeded.
///
/// Returns the specification consisting of both tile values and the number of
/// tiles of each size.
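///
/// For example, with originalTripCount = 113, targetSize = 32, and
/// sizeDivisor = 1, the formulas above give:
///
///   b = 113, t = 32, d = (113 + 31) floordiv 32 = 4,
///   s = (113 floordiv 4) * 1 = 28, v = 113 mod 4 = 1, u = 3,
///
/// i.e. the two tile sizes are 28 and 29, and 28 * 3 + 29 * 1 == 113.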
FailureOr<MultiSizeSpecification>
computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
    OpFoldResult targetSize, OpFoldResult divisor,
    bool emitAssertions = true);
FailureOr<StaticMultiSizeSpecification>
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
    int64_t divisor);

FailureOr<StaticContinuousTileSizeSpecification>
computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
    unsigned targetSize);
FailureOr<ContinuousTileSizeSpecification>
computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
    unsigned dimension, OpFoldResult targetSize,
    bool emitAssertions);

/// Transformation information returned after reduction tiling.
struct ForallReductionTilingResult {
  /// The partial reduction tiled op generated.
  SmallVector<Operation *> parallelTiledOps;
  /// The final reduction operation merging all the partial reductions.
  SmallVector<Operation *> mergeOps;
  /// Initial values used for partial reductions.
  SmallVector<Value> initialValues;
  /// The `scf.forall` operation that iterates over the tiles.
  scf::ForallOp loops;
};

/// Method to tile a reduction to parallel iterations computing partial
/// reductions. After the loop all the partial reductions are merged into a final
/// reduction. For example for the following sequence
///
/// ```mlir
/// %0 = linalg.generic %in ["parallel", "reduction"]
/// : tensor<7x9xf32> -> tensor<7xf32>
/// ```
///
/// into:
///
/// ```mlir
/// %0 = linalg.fill ... : tensor<7x4xf32>
/// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0)
/// -> (tensor<7x4xf32>) {
/// %2 = tensor.extract_slice %arg3 : tensor<7x4xf32> to tensor<7xf32>
/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
/// %4 = linalg.generic %2, %3 ["parallel", "reduction"]
/// : tensor<7x?xf32> -> tensor<7xf32>
/// %5 = tensor.insert_slice %3, %arg0[0, %iv] : tensor<7x4xf32>
/// }
/// %6 = linalg.generic %1 ["parallel", "reduction"]
/// : tensor<7x4xf32> -> tensor<7xf32>
/// ```
FailureOr<ForallReductionTilingResult>
tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op,
    ArrayRef<OpFoldResult> numThreads,
    ArrayRef<OpFoldResult> tileSizes = {},
    std::optional<ArrayAttr> mapping = std::nullopt);
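// Example usage (sketch): tile the reduction of the rank-2 op from the example
// above across 4 threads, assuming `rewriter` and `linalgOp` are in scope. The
// `numThreads` encoding here (one entry per loop, 0 meaning "not tiled along
// this loop") mirrors the example and is an assumption of this sketch.
//
//   SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(0),
//                                           rewriter.getIndexAttr(4)};
//   FailureOr<ForallReductionTilingResult> res = tileReductionUsingForall(
//       rewriter, cast<PartialReductionOpInterface>(linalgOp.getOperation()),
//       numThreads);
//   if (failed(res))
//     return failure();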

/// All indices returned by IndexOp should be invariant with respect to
/// tiling. Therefore, if an operation is tiled, we have to transform the
/// indices accordingly, i.e. offset them by the values of the corresponding
/// induction variables that are captured implicitly in the body of the op.
///
/// Example. `linalg.generic` before tiling:
///
/// #id_2d = (i, j) -> (i, j)
/// #pointwise_2d_trait = {
/// indexing_maps = [#id_2d, #id_2d],
/// iterator_types = ["parallel", "parallel"]
/// }
/// linalg.generic #pointwise_2d_trait %operand, %result {
/// ^bb0(%operand_in: f32, %result_in: f32):
/// %i = linalg.index 0 : index
/// %j = linalg.index 1 : index
/// <some operations that use %i, %j>
/// }: memref<50x100xf32>, memref<50x100xf32>
///
/// After the tiling pass with tile sizes 10 and 25:
///
/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
///
/// %c1 = arith.constant 1 : index
/// %c0 = arith.constant 0 : index
/// %c25 = arith.constant 25 : index
/// %c10 = arith.constant 10 : index
/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
/// scf.for %k = %c0 to operand_dim_0 step %c10 {
/// scf.for %l = %c0 to operand_dim_1 step %c25 {
/// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
/// : memref<50x100xf32> to memref<?x?xf32, #strided>
/// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
/// : memref<50x100xf32> to memref<?x?xf32, #strided>
/// linalg.generic pointwise_2d_trait %4, %5 {
/// ^bb0(%operand_in: f32, %result_in: f32):
/// %i = linalg.index 0 : index
/// %j = linalg.index 1 : index
/// // Indices `k` and `l` are implicitly captured in the body.
/// %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
/// %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
/// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
/// <some operations that use %transformed_i, %transformed_j>
/// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
/// }
/// }
///
/// TODO: Investigate whether mixing implicit and explicit indices
/// does not lead to losing information.
void transformIndexOps(RewriterBase &b, LinalgOp op,
    SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);

/// Apply transformation to split the single linalg op reduction into a
/// parallel and reduction dimension. Then create a new linalg.generic op
/// doing the rest of the reduction. Return the new linalg op with an extra
/// parallel dimension or failure if the transformation didn't happen.
///
/// Example:
/// ```
/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
/// affine_map<(d0) -> ()>],
/// iterator_types = ["reduction"]}
/// ins(%in : tensor<32xf32>)
/// outs(%out : tensor<f32>) {
/// ^bb0(%arg1: f32, %arg2: f32):
/// %y = arith.addf %arg1, %arg2 : f32
/// linalg.yield %y : f32
/// } -> tensor<f32>
/// ```
/// To:
/// ```
/// %cst = arith.constant 0.000000e+00 : f32
/// %0 = tensor.expand_shape %in [[0, 1]]: tensor<32xf32> into tensor<4x8xf32>
/// %1 = tensor.empty [4] : tensor<4xf32>
/// %2 = linalg.fill ins(%cst : f32)
/// outs(%1 : tensor<4xf32>) -> tensor<4xf32>
/// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
/// affine_map<(d0, d1) -> (d0)>],
/// iterator_types = ["parallel", "reduction"]}
/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
/// ^bb0(%arg3: f32, %arg4: f32):
/// %5 = arith.addf %arg3, %arg4 : f32
/// linalg.yield %5 : f32
/// } -> tensor<4xf32>
/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
/// affine_map<(d0) -> ()>],
/// iterator_types = ["reduction"]}
/// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
/// ^bb0(%arg3: f32, %arg4: f32):
/// %5 = arith.addf %arg3, %arg4 : f32
/// linalg.yield %5 : f32
/// } -> tensor<f32>
/// ```
struct SplitReductionResult {
  Operation *initOrAlloc;
  FillOp fillOp;
  LinalgOp splitLinalgOp;
  LinalgOp resultCombiningLinalgOp;
};
FailureOr<SplitReductionResult>
splitReduction(RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    bool useAlloc = false);

/// Scaling-based implementation of the split reduction transformation.
/// Instead of introducing an ExpandShapeOp, this rewrites a reduction
/// dimension `k` into `k * scale + kk`.
///
/// Example:
/// ```
/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
/// ```
///
/// Is transformed to:
///
/// ```
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32>
/// %cst = arith.constant 0.000000e+00 : f32
/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
/// tensor<16x32x64xf32>
/// %2 = tensor.empty [64, 4] : tensor<64x4xi1>
///
/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>,
/// tensor<64x4xi1>)
/// outs(%1 : tensor<16x32x64xf32>) {
/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
/// %5 = arith.mulf %arg3, %arg4 : f32
/// %6 = arith.addf %arg6, %5 : f32
/// linalg.yield %6 : f32
/// } -> tensor<16x32x64xf32>
///
/// %4 = linalg.generic {indexing_maps = [#map4, #map5],
/// iterator_types = ["parallel", "parallel", "reduction"]}
/// ins(%3 : tensor<16x32x64xf32>)
/// outs(%C : tensor<16x32xf32>) {
/// ^bb0(%arg3: f32, %arg4: f32):
/// %5 = arith.addf %arg3, %arg4 : f32
/// linalg.yield %5 : f32
/// } -> tensor<16x32xf32>
///
/// return %4 : tensor<16x32xf32>
/// ```
FailureOr<SplitReductionResult>
splitReductionByScaling(RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    bool useAlloc = false);

/// Return `true` if a given sequence of dimensions are contiguous in the
/// range of the specified indexing map.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
/// Return `true` if all sequences of dimensions specified in `dimSequences` are
/// contiguous in all the ranges of the `maps`.
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
    ArrayRef<ReassociationIndices> dimSequences);

struct CollapseResult {
  SmallVector<Value> results;
  LinalgOp collapsedOp;
};

/// Collapses dimensions of a linalg.generic/linalg.copy operation. A precondition
/// to calling this method is that for each list in `foldedIterationDims`, the
/// sequence of dimensions is contiguous in the domains of all `indexing_maps` of
/// the `linalgOp`. This can be checked using the `areDimSequencesPreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
FailureOr<CollapseResult>
collapseOpIterationDims(LinalgOp op,
    ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter);

struct LowerPackResult {
  tensor::PadOp padOp;
  tensor::ExpandShapeOp expandShapeOp;
  linalg::TransposeOp transposeOp;
};

/// Rewrite pack as pad + reshape + transpose.
FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
    linalg::PackOp packOp,
    bool lowerPadLikeWithInsertSlice = true);

struct LowerUnPackOpResult {
  tensor::EmptyOp emptyOp;
  linalg::TransposeOp transposeOp;
  tensor::CollapseShapeOp collapseShapeOp;
  tensor::ExtractSliceOp extractSliceOp;
};

/// Rewrite unpack as empty + transpose + reshape + extract_slice.
FailureOr<LowerUnPackOpResult>
lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
    bool lowerUnpadLikeWithExtractSlice = true);

/// Struct to hold the result of a `pack` call.
struct PackResult {
  SmallVector<linalg::PackOp> packOps;
  linalg::LinalgOp packedLinalgOp;
  SmallVector<linalg::UnPackOp> unPackOps;
};
/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
    ArrayRef<OpFoldResult> packedSizes);
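// Example usage (sketch): pack a linalg op with per-iterator packing sizes,
// assuming `rewriter` and `linalgOp` are in scope. One entry is required per
// iterator of `linalgOp`; the sizes below are illustrative for a 3-loop matmul.
//
//   SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
//                                            rewriter.getIndexAttr(16),
//                                            rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> packed = pack(rewriter, linalgOp, packedSizes);
//   if (failed(packed))
//     return failure();
//   // `packed->packedLinalgOp` now operates on the packed operands.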
| 1152 | |
| 1153 | /// Struct to hold the result of a `packTranspose` call. |
| 1154 | struct PackTransposeResult { |
| 1155 | linalg::PackOp transposedPackOp; |
| 1156 | linalg::LinalgOp transposedLinalgOp; |
| 1157 | linalg::UnPackOp transposedUnPackOp; |
| 1158 | }; |
| 1159 | /// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the |
| 1160 | /// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements. |
| 1161 | /// Return failure if any of the following holds:
| 1162 | /// 1. the `packOp` does not have the `linalgOp` as its unique use.
| 1163 | /// 2. the `maybeUnPackOp`, if specified, is not a consumer of the result tied
| 1164 | /// to the unique `packOp` use.
| 1165 | /// 3. `outerPerm` (resp. `innerPerm`) is neither empty nor a valid permutation
| 1166 | /// of `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`).
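| | ///
| | /// A sketch, assuming `packRes` is the `PackResult` of a previous `pack` call
| | /// with a single pack/unpack pair and two outer dimensions on the pack:
| | /// ```
| | /// FailureOr<PackTransposeResult> transposed = packTranspose(
| | ///     rewriter, packRes.packOps.front(), packRes.packedLinalgOp,
| | ///     packRes.unPackOps.front(), /*outerPerm=*/{1, 0}, /*innerPerm=*/{});
| | /// ```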
| 1167 | FailureOr<PackTransposeResult> |
| 1168 | packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, |
| 1169 | linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, |
| 1170 | ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm); |
| 1171 | |
| 1172 | /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
| 1173 | /// and n are proper parallel dimensions and k is a proper reduction
| 1174 | /// dimension. Packing occurs by rewriting the op as a linalg.generic and
| 1175 | /// calling linalg::pack with `mnkPackedSizes`. The order of the packed
| 1176 | /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
| 1177 | /// to reorder {m, n, k} into one of the 6 possible forms. The outer
| 1178 | /// dimensions of the operands are not permuted at this time; this is left for
| 1179 | /// future work.
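| | ///
| | /// A sketch, assuming `rewriter` and a contraction-like `linalgOp` exist; the
| | /// packed sizes below are illustrative only:
| | /// ```
| | /// SmallVector<OpFoldResult> mnkSizes = getAsIndexOpFoldResult(
| | ///     rewriter.getContext(), ArrayRef<int64_t>{16, 32, 64});
| | /// FailureOr<PackResult> res = packMatmulGreedily(
| | ///     rewriter, linalgOp, mnkSizes, /*mnkPaddedSizesNextMultipleOf=*/{},
| | ///     /*mnkOrder=*/{0, 1, 2});
| | /// ```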
| 1180 | FailureOr<PackResult> |
| 1181 | packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, |
| 1182 | ArrayRef<OpFoldResult> mnkPackedSizes, |
| 1183 | ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf, |
| 1184 | ArrayRef<int64_t> mnkOrder); |
| 1185 | |
| 1186 | struct BlockPackMatmulOptions { |
| 1187 | /// Minor block factors (mb, nb, kb) for packing relayout where mb and nb are
| 1188 | /// the parallel dimensions and kb is the reduction dimension. |
| 1189 | SmallVector<int64_t, 3> blockFactors; |
| 1190 | |
| 1191 | /// If true, allows packing of dimensions that only partially fit into the |
| 1192 | /// block factors. |
| 1193 | bool allowPadding = true; |
| 1194 | |
| 1195 | /// Next multiples of the packing sizes. |
| 1196 | SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf; |
| 1197 | |
| 1198 | /// Permutation of matmul (M, N, K) dimensions order. |
| 1199 | SmallVector<int64_t, 3> mnkOrder = {0, 1, 2}; |
| 1200 | |
| 1201 | /// Transpose LHS outer block layout [MB][KB] -> [KB][MB]. |
| 1202 | bool lhsTransposeOuterBlocks = false; |
| 1203 | |
| 1204 | /// Transpose LHS inner block layout [mb][kb] -> [kb][mb]. |
| 1205 | bool lhsTransposeInnerBlocks = false; |
| 1206 | |
| 1207 | /// Transpose RHS outer block layout [KB][NB] -> [NB][KB]. |
| 1208 | bool rhsTransposeOuterBlocks = true; |
| 1209 | |
| 1210 | /// Transpose RHS inner block layout [kb][nb] -> [nb][kb]. |
| 1211 | bool rhsTransposeInnerBlocks = true; |
| 1212 | }; |
| 1213 | |
| 1214 | /// Function type which is used to control matmul packing. |
| 1215 | /// It is expected to return a valid packing configuration for each operation.
| 1216 | /// Lack of packing options indicates that no valid configuration could be |
| 1217 | /// assigned and the operation will not be packed. |
| 1218 | using ControlBlockPackMatmulFn = |
| 1219 | std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>; |
| 1220 | |
| 1221 | /// Pack a matmul operation into blocked 4D layout. |
| 1222 | /// |
| 1223 | /// Relayout a matmul operation into blocked layout with two levels of |
| 1224 | /// subdivision: |
| 1225 | /// - major 2D blocks - outer dimensions, consist of minor blocks |
| 1226 | /// - minor 2D blocks - inner dimensions, consist of scalar elements |
| 1227 | /// |
| 1228 | /// A 2D matmul MxNxK gets reshaped into blocked 4D representation |
| 1229 | /// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb] |
| 1230 | /// where the (MB, NB, KB) dimensions represent the major blocks, |
| 1231 | /// and the (mb, nb, kb) are the minor blocks of their respective |
| 1232 | /// original 2D dimensions (M, N, K). |
| 1233 | /// |
| 1234 | /// Depending on the initial operands' data layout and the specified |
| 1235 | /// packing options, the major blocks dimensions might get transposed |
| 1236 | /// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed |
| 1237 | /// e.g., [mb][kb] -> [kb][mb]. |
| 1238 | /// Any present batch dimensions remain unchanged. |
| 1239 | /// The final result is unpacked back to the original shape. |
| 1240 | /// |
| 1241 | /// Return failure if no valid packing options are provided. |
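| | ///
| | /// A sketch of a control callback; the block factors are illustrative and
| | /// all other options keep their defaults (`rewriter` and `matmulOp` assumed):
| | /// ```
| | /// ControlBlockPackMatmulFn control =
| | ///     [](LinalgOp op) -> std::optional<BlockPackMatmulOptions> {
| | ///   BlockPackMatmulOptions options;
| | ///   options.blockFactors = {32, 32, 64};
| | ///   return options;
| | /// };
| | /// FailureOr<PackResult> res = blockPackMatmul(rewriter, matmulOp, control);
| | /// ```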
| 1242 | FailureOr<PackResult> |
| 1243 | blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, |
| 1244 | const ControlBlockPackMatmulFn &controlPackMatmul); |
| 1245 | |
| 1246 | /// Rewrite tensor.from_elements to linalg.generic. |
| 1247 | FailureOr<Operation *> |
| 1248 | rewriteInDestinationPassingStyle(RewriterBase &rewriter, |
| 1249 | tensor::FromElementsOp fromElementsOp); |
| 1250 | |
| 1251 | /// Rewrite tensor.generate to linalg.generic. |
| 1252 | FailureOr<Operation *> |
| 1253 | rewriteInDestinationPassingStyle(RewriterBase &rewriter, |
| 1254 | tensor::GenerateOp generateOp); |
| 1255 | |
| 1256 | /// Rewrite tensor.pad to linalg.generic + tensor.insert_slice. |
| 1257 | FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter, |
| 1258 | tensor::PadOp padOp); |
| 1259 | |
| 1260 | /// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) |
| 1261 | /// and linalg.matmul. |
| 1262 | /// |
| 1263 | /// A convolution operation can be written as a matrix-matrix multiplication by |
| 1264 | /// unfolding the cross-correlation between input and filter and explicitly
| 1265 | /// copying the overlapped sliding window inputs.
| 1266 | /// |
| 1267 | /// Consider 2D input X with single channel input and output and 2x2 filter W: |
| 1268 | /// [x(0, 0) , x(0, 1) , ..., x(0, n) ] |
| 1269 | /// [x(1, 0) , x(1, 1) , ..., x(1, n) ] |
| 1270 | /// [. , . ,. , . ] [w(0, 0), w(0, 1)] |
| 1271 | /// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)] |
| 1272 | /// [. , . , ., . ] |
| 1273 | /// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)] |
| 1274 | /// |
| 1275 | /// The packed input data (img2col) is a matrix with |rows| = output spatial |
| 1276 | /// size, |columns| = filter spatial size. To compute the output Y(i, j) we need |
| 1277 | /// to calculate the dot product between the filter window at input X(x, y) and
| 1278 | /// the filter, which looks like the following, where the r.h.s. is the img2col
| 1279 | /// matrix and the l.h.s. is the flattened filter:
| 1280 | /// |
| 1281 | /// [x(0,0), x(0,1), x(1,0), x(1,1)] |
| 1282 | /// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)] |
| 1283 | /// [x(0,1), x(1,1), x(0,2), x(1,2)] |
| 1284 | /// [ . , . , . , . ] |
| 1285 | /// |
| 1286 | /// In general, for the 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter
| 1287 | /// and (N, Ho, Wo, D) output, the convolution is the following matrix-matrix
| 1288 | /// multiplication: (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the
| 1289 | /// N inputs. For the case where N > 1 it is a batched matrix-matrix
| 1290 | /// multiplication. |
| 1291 | /// |
| 1292 | /// On success, return both the operation that produces the img2col tensor and |
| 1293 | /// the final operation of the sequence that replaces the original convolution. |
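| | ///
| | /// A sketch, assuming `rewriter` and a `linalg::Conv2DNhwcHwcfOp` named
| | /// `convOp` exist:
| | /// ```
| | /// FailureOr<std::pair<Operation *, Operation *>> img2col =
| | ///     rewriteInIm2Col(rewriter, convOp);
| | /// if (succeeded(img2col)) {
| | ///   // First = producer of the img2col tensor, second = the op replacing
| | ///   // the original convolution.
| | ///   Operation *img2colProducer = img2col->first;
| | ///   Operation *replacementOp = img2col->second;
| | /// }
| | /// ```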
| 1294 | FailureOr<std::pair<Operation *, Operation *>> |
| 1295 | rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp); |
| 1296 | |
| 1297 | /// Same as the above but for Fhwc channel orderings in the filter. In this case |
| 1298 | /// the matrix multiplication is actually a row-wise dot-product rather than a |
| 1299 | /// row-column dot-product. This is to avoid transposing the filter matrix which |
| 1300 | /// would be required for a regular matrix multiplication to produce the correct |
| 1301 | /// output dimensions. |
| 1302 | FailureOr<std::pair<Operation *, Operation *>> |
| 1303 | rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp); |
| 1304 | |
| 1305 | /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that there
| 1306 | /// is no reduction among the input channels, so each convolution can be a
| 1307 | /// matrix-vector product; by transposing both the input and the filter so that
| 1308 | /// the channels are outermost, the computation is a batched matrix-vector product.
| 1309 | FailureOr<std::pair<Operation *, Operation *>> |
| 1310 | rewriteInIm2Col(RewriterBase &rewriter, |
| 1311 | linalg::DepthwiseConv2DNhwcHwcOp convOp); |
| 1312 | |
| 1313 | /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that, because
| 1314 | /// the channels are to the left of the image shape dimensions, the position of
| 1315 | /// the contraction dimension in the resulting matmul is reversed. This swaps the
| 1316 | /// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) *
| 1317 | /// (C x Kh x Kw, Ho x Wo)).
| 1318 | FailureOr<std::pair<Operation *, Operation *>> |
| 1319 | rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp); |
| 1320 | |
| 1321 | /// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by |
| 1322 | /// materializing transpose. |
| 1323 | FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, |
| 1324 | linalg::Conv2DNhwcFhwcOp op); |
| 1325 | FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, |
| 1326 | linalg::Conv2DNhwcFhwcQOp op); |
| 1327 | |
| 1328 | /// Convert Linalg matmul ops to transposed variants. |
| 1329 | FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter, |
| 1330 | linalg::MatmulOp op, |
| 1331 | bool transposeLHS = true); |
| 1332 | FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter, |
| 1333 | linalg::BatchMatmulOp op, |
| 1334 | bool transposeLHS = true); |
| 1335 | |
| 1336 | /// Convert linalg.conv_2d_nhwc_fhwc to the Winograd Conv2D algorithm
| 1337 | /// F(m x m, r x r), where m is the dimension size of the output and r is the
| 1338 | /// dimension size of the filter.
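| | ///
| | /// For example, a 3x3 filter producing 4x4 output tiles uses F(4 x 4, 3 x 3)
| | /// (a sketch; `rewriter` and `convOp` are assumed to exist):
| | /// ```
| | /// FailureOr<Operation *> winograd =
| | ///     winogradConv2D(rewriter, convOp, /*m=*/4, /*r=*/3);
| | /// ```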
| 1339 | FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, |
| 1340 | linalg::Conv2DNhwcFhwcOp op, int64_t m, |
| 1341 | int64_t r); |
| 1342 | |
| 1343 | /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is |
| 1344 | /// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
| 1345 | /// from FHWC first. We generate 2 levels of loops to iterate on F and C. After |
| 1346 | /// the rewriting, we get |
| 1347 | /// |
| 1348 | /// scf.for %f = lo_f to hi_f step 1 |
| 1349 | /// scf.for %c = lo_c to hi_c step 1 |
| 1350 | /// %extracted = extract filter<h x w> from filter<f x h x w x c> |
| 1351 | /// %ret = linalg.matmul G, %extracted |
| 1352 | /// %ret = linalg.matmul %ret, GT |
| 1353 | /// %inserted = insert %ret into filter<h x w x c x f> |
| 1354 | FailureOr<Operation *> |
| 1355 | decomposeWinogradFilterTransformOp(RewriterBase &rewriter, |
| 1356 | linalg::WinogradFilterTransformOp op); |
| 1357 | |
| 1358 | /// Rewrite linalg.winograd_input_transform. The data layout of the input is |
| 1359 | /// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
| 1360 | /// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH, |
| 1361 | /// and tileW. After the rewriting, we get |
| 1362 | /// |
| 1363 | /// scf.for %h = 0 to tileH step 1 |
| 1364 | /// scf.for %w = 0 to tileW step 1 |
| 1365 | /// scf.for %n = 0 to N step 1 |
| 1366 | /// scf.for %c = 0 to C step 1 |
| 1367 | /// %extracted = extract %extracted<alphaH x alphaW> from |
| 1368 | /// %input<N x H x W x C> |
| 1369 | /// at [%n, (%h x m), (%w x m), %c] |
| 1370 | /// %ret = linalg.matmul BT, %extracted |
| 1371 | /// %ret = linalg.matmul %ret, B |
| 1372 | /// %inserted = insert %ret<alphaH x alphaW> into |
| 1373 | /// %output<alphaH x alphaW x tileH x tileW x N x C> |
| 1374 | /// at [0, 0, %h, %w, %n, %c] |
| 1375 | FailureOr<Operation *> |
| 1376 | decomposeWinogradInputTransformOp(RewriterBase &rewriter, |
| 1377 | linalg::WinogradInputTransformOp op); |
| 1378 | |
| 1379 | /// Rewrite linalg.winograd_output_transform. The data layout of the output is |
| 1380 | /// HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
| 1381 | /// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH, |
| 1382 | /// and tileW. After the transformation, we get |
| 1383 | /// |
| 1384 | /// scf.for %h = 0 to tileH step 1 |
| 1385 | /// scf.for %w = 0 to tileW step 1 |
| 1386 | /// scf.for %n = 0 to N step 1 |
| 1387 | /// scf.for %f = 0 to F step 1 |
| 1388 | /// %extracted = extract %extracted<alphaH x alphaW> from |
| 1389 | /// %input<alphaH x alphaW x tileH x tileW x N x F> |
| 1390 | /// at [0, 0, %h, %w, %n, %f] |
| 1391 | /// %ret = linalg.matmul AT, %extracted |
| 1392 | /// %ret = linalg.matmul %ret, A |
| 1393 | /// %inserted = insert %ret<alphaH x alphaW> into |
| 1394 | /// output<N x H x W x F> |
| 1395 | /// at [%n, (%h x m), (%w x m), %f] |
| 1396 | FailureOr<Operation *> |
| 1397 | decomposeWinogradOutputTransformOp(RewriterBase &rewriter, |
| 1398 | linalg::WinogradOutputTransformOp op); |
| 1399 | |
| 1400 | /// Method to deduplicate operands and remove dead results of `linalg.generic` |
| 1401 | /// operations. This is effectively DCE for a linalg.generic op. If there is |
| 1402 | /// deduplication of operands or removal of results, replaces the `genericOp`
| 1403 | /// with a new op and returns it. Returns the same operation if there is no |
| 1404 | /// deduplication/removal. |
| 1405 | FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults( |
| 1406 | RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs); |
| 1407 | |
| 1408 | //===----------------------------------------------------------------------===// |
| 1409 | // Rewrite patterns wrapping transformations. |
| 1410 | // TODO: every single such pattern should be a close to noop wrapper around a |
| 1411 | // functional-style API call.
| 1412 | //===----------------------------------------------------------------------===// |
| 1413 | |
| 1414 | /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D |
| 1415 | /// convolution ops. |
| 1416 | template <typename Conv2DOp, typename Conv1DOp> |
| 1417 | struct DownscaleSizeOneWindowed2DConvolution final |
| 1418 | : public OpRewritePattern<Conv2DOp> { |
| 1419 | using OpRewritePattern<Conv2DOp>::OpRewritePattern; |
| 1420 | |
| 1421 | FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp, |
| 1422 | PatternRewriter &rewriter) const; |
| 1423 | |
| 1424 | LogicalResult matchAndRewrite(Conv2DOp convOp, |
| 1425 | PatternRewriter &rewriter) const override { |
| 1426 | return returningMatchAndRewrite(convOp, rewriter); |
| 1427 | } |
| 1428 | }; |
| 1429 | |
| 1430 | extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, |
| 1431 | Conv1DNwcWcfOp>; |
| 1432 | extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp, |
| 1433 | Conv1DNcwFcwOp>; |
| 1434 | |
| 1435 | /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) |
| 1436 | /// dimensions into 1-D depthwise convolution ops. |
| 1437 | struct DownscaleDepthwiseConv2DNhwcHwcOp final |
| 1438 | : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> { |
| 1439 | DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context, |
| 1440 | PatternBenefit benefit = 1) |
| 1441 | : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {} |
| 1442 | |
| 1443 | FailureOr<DepthwiseConv1DNwcWcOp> |
| 1444 | returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, |
| 1445 | PatternRewriter &rewriter) const; |
| 1446 | |
| 1447 | LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, |
| 1448 | PatternRewriter &rewriter) const override { |
| 1449 | return returningMatchAndRewrite(convOp, rewriter); |
| 1450 | } |
| 1451 | }; |
| 1452 | |
| 1453 | struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> { |
| 1454 | DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1) |
| 1455 | : OpRewritePattern<Conv2DOp>(context, benefit) {} |
| 1456 | |
| 1457 | FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp, |
| 1458 | PatternRewriter &rewriter) const; |
| 1459 | |
| 1460 | LogicalResult matchAndRewrite(Conv2DOp convOp, |
| 1461 | PatternRewriter &rewriter) const override { |
| 1462 | return returningMatchAndRewrite(convOp, rewriter); |
| 1463 | } |
| 1464 | }; |
| 1465 | |
| 1466 | /// |
| 1467 | /// Linalg generalization pattern. |
| 1468 | /// |
| 1469 | /// Apply the `generalization` transformation as a pattern. |
| 1470 | /// See `generalization` for more details. |
| 1471 | // |
| 1472 | // TODO: Automatic default pattern class that just unwraps a function |
| 1473 | // returning FailureOr<GenericOp>. |
| 1474 | struct LinalgGeneralizationPattern |
| 1475 | : public OpInterfaceRewritePattern<LinalgOp> { |
| 1476 | using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; |
| 1477 | |
| 1478 | /// `matchAndRewrite` implementation that returns the significant |
| 1479 | /// transformed pieces of IR. |
| 1480 | FailureOr<GenericOp> |
| 1481 | returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const { |
| 1482 | return generalizeNamedOp(rewriter, op); |
| 1483 | } |
| 1484 | |
| 1485 | LogicalResult matchAndRewrite(LinalgOp op, |
| 1486 | PatternRewriter &rewriter) const override { |
| 1487 | return returningMatchAndRewrite(op, rewriter); |
| 1488 | } |
| 1489 | }; |
| 1490 | |
| 1491 | struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> { |
| 1492 | using OpRewritePattern<GenericOp>::OpRewritePattern; |
| 1493 | |
| 1494 | FailureOr<GenericOp> |
| 1495 | returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const { |
| 1496 | return specializeGenericOp(rewriter, op); |
| 1497 | } |
| 1498 | |
| 1499 | LogicalResult matchAndRewrite(GenericOp op, |
| 1500 | PatternRewriter &rewriter) const override { |
| 1501 | return returningMatchAndRewrite(op, rewriter);
| 1502 | } |
| 1503 | }; |
| 1504 | |
| 1505 | /// Vectorization pattern for memref::CopyOp. |
| 1506 | struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> { |
| 1507 | using OpRewritePattern<memref::CopyOp>::OpRewritePattern; |
| 1508 | |
| 1509 | LogicalResult matchAndRewrite(memref::CopyOp copyOp, |
| 1510 | PatternRewriter &rewriter) const override; |
| 1511 | }; |
| 1512 | |
| 1513 | using OptimizeCopyFn = |
| 1514 | std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>; |
| 1515 | |
| 1516 | /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and |
| 1517 | /// InsertSliceOp. For now, only constant padding values are supported. |
| 1518 | struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> { |
| 1519 | DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1) |
| 1520 | : OpRewritePattern<tensor::PadOp>(context, benefit) {} |
| 1521 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
| 1522 | PatternRewriter &rewriter) const override; |
| 1523 | |
| 1524 | protected: |
| 1525 | Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, |
| 1526 | Value dest, |
| 1527 | const SmallVector<Value> &dynSizes) const; |
| 1528 | }; |
| 1529 | |
| 1530 | /// Rewrites a linalg::PackOp into a sequence of: |
| 1531 | /// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp + |
| 1532 | /// tensor::InsertSliceOp ops. |
| 1533 | /// |
| 1534 | /// Requires that all the outer dims of the input linalg::PackOp are 1. |
| 1535 | /// |
| 1536 | /// Before: |
| 1537 | /// ``` |
| 1538 | /// %packed = linalg.pack %input |
| 1539 | /// padding_value(%pad : f32) |
| 1540 | /// inner_dims_pos = [1, 0] |
| 1541 | /// inner_tiles = [2, %high] |
| 1542 | /// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32> |
| 1543 | /// ``` |
| 1544 | /// |
| 1545 | /// After: |
| 1546 | /// ``` |
| 1547 | /// // PadOp |
| 1548 | /// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { |
| 1549 | /// ^bb0(...): |
| 1550 | /// tensor.yield %arg2 : f32 |
| 1551 | /// } : tensor<5x1xf32> to tensor<?x2xf32> |
| 1552 | /// // EmptyOp + TransposeOp |
| 1553 | /// %empty = tensor.empty(%arg3) : tensor<2x?xf32> |
| 1554 | /// %transposed = linalg.transpose |
| 1555 | /// ins(%padded : tensor<?x2xf32>)
| 1556 | /// outs(%empty : tensor<2x?xf32>) |
| 1557 | /// permutation = [1, 0] |
| 1558 | /// // InsertSliceOp |
| 1559 | /// %inserted_slice = tensor.insert_slice %transposed |
| 1560 | /// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1] |
| 1561 | /// : tensor<2x?xf32> into tensor<1x1x2x?xf32> |
| 1562 | /// ``` |
| 1563 | struct DecomposeOuterUnitDimsPackOpPattern |
| 1564 | : public OpRewritePattern<linalg::PackOp> { |
| 1565 | using OpRewritePattern<linalg::PackOp>::OpRewritePattern; |
| 1566 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
| 1567 | PatternRewriter &rewriter) const override; |
| 1568 | }; |
| 1569 | |
| 1570 | /// Rewrites a linalg::UnPackOp into a sequence of rank-reduced |
| 1571 | /// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp |
| 1572 | /// |
| 1573 | /// Requires that all the outer dims of the input linalg::UnPackOp are 1.
| 1574 | /// |
| 1575 | /// Before: |
| 1576 | /// ``` |
| 1577 | /// %packed = linalg.unpack %input |
| 1578 | /// inner_dims_pos = [1, 0] |
| 1579 | /// inner_tiles = [2, 8] |
| 1580 | /// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32> |
| 1581 | /// ``` |
| 1582 | /// |
| 1583 | /// After: |
| 1584 | /// ``` |
| 1585 | /// // Rank-reduced extract to obtain the tile |
| 1586 | /// %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1] |
| 1587 | /// : tensor<1x1x2x8xf32> to tensor<2x8xf32> |
| 1588 | /// // EmptyOp + TransposeOp |
| 1589 | /// %init = tensor.empty() : tensor<8x2xf32> |
| 1590 | /// %transposed = linalg.transpose |
| 1591 | /// ins(%slice : tensor<2x8xf32>)
| 1592 | /// outs(%init : tensor<8x2xf32>) permutation = [1, 0]
| 1593 | /// // Extract a slice matching the specified output size |
| 1594 | /// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1] |
| 1595 | /// : tensor<8x2xf32> to tensor<5x1xf32> |
| 1596 | /// ``` |
| 1597 | struct DecomposeOuterUnitDimsUnPackOpPattern |
| 1598 | : public OpRewritePattern<linalg::UnPackOp> { |
| 1599 | using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern; |
| 1600 | LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, |
| 1601 | PatternRewriter &rewriter) const override; |
| 1602 | }; |
| 1603 | |
| 1604 | /// Match and rewrite for the pattern: |
| 1605 | /// ``` |
| 1606 | /// %alloc = ... |
| 1607 | /// [optional] %view = memref.view %alloc ... |
| 1608 | /// %subView = subview %allocOrView ... |
| 1609 | /// [optional] linalg.fill(%allocOrView, %cst) ... |
| 1610 | /// ... |
| 1611 | /// memref.copy(%in, %subView) ... |
| 1612 | /// vector.transfer_read %allocOrView[...], %cst ... |
| 1613 | /// ``` |
| 1614 | /// into |
| 1615 | /// ``` |
| 1616 | /// [unchanged] %alloc = ... |
| 1617 | /// [unchanged] [optional] %view = memref.view %alloc ... |
| 1618 | /// [unchanged] %subView = subview %allocOrView ...
| 1619 | /// ... |
| 1620 | /// vector.transfer_read %in[...], %cst ... |
| 1621 | /// ``` |
| 1622 | /// Where there is no interleaved use between memref.copy and transfer_read as |
| 1623 | /// well as no interleaved use between linalg.fill and memref.copy (if |
| 1624 | /// linalg.fill is specified). |
| 1625 | /// This is a custom rewrite to forward partial reads (with optional fills) to |
| 1626 | /// vector.transfer_read. |
| 1627 | struct LinalgCopyVTRForwardingPattern |
| 1628 | : public OpRewritePattern<vector::TransferReadOp> { |
| 1629 | using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
| 1630 | |
| 1631 | LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, |
| 1632 | PatternRewriter &rewriter) const override; |
| 1633 | }; |
| 1634 | |
| 1635 | /// Match and rewrite for the pattern: |
| 1636 | /// ``` |
| 1637 | /// %alloc = ... |
| 1638 | /// [optional] %view = memref.view %alloc ... |
| 1639 | /// %subView = subview %allocOrView... |
| 1640 | /// ... |
| 1641 | /// vector.transfer_write %..., %allocOrView[...] |
| 1642 | /// memref.copy(%subView, %out) |
| 1643 | /// ``` |
| 1644 | /// into |
| 1645 | /// ``` |
| 1646 | /// [unchanged] %alloc = ... |
| 1647 | /// [unchanged] [optional] %view = memref.view %alloc ... |
| 1648 | /// [unchanged] %subView = subview %allocOrView... |
| 1649 | /// ... |
| 1650 | /// vector.transfer_write %..., %out[...] |
| 1651 | /// ``` |
| 1652 | /// Where there is no interleaved use between transfer_write and memref.copy. |
| 1653 | /// This is a custom rewrite to forward partial writes to |
| 1654 | /// vector.transfer_write. |
| 1655 | struct LinalgCopyVTWForwardingPattern |
| 1656 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 1657 | using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
| 1658 | |
| 1659 | LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, |
| 1660 | PatternRewriter &rewriter) const override; |
| 1661 | }; |
| 1662 | |
| 1663 | /// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)). |
| 1664 | struct ExtractSliceOfPadTensorSwapPattern
| 1665 | : public OpRewritePattern<tensor::ExtractSliceOp> {
| 1666 | /// A function to control pattern application and rewrite logic. |
| 1667 | /// |
| 1668 | /// The function will be given the slice op and should return: |
| 1669 | /// - std::nullopt: to fail the match and not apply the pattern; |
| 1670 | /// - true: to apply the pattern with zero slice guard; |
| 1671 | /// - false: to apply the pattern without zero slice guard. |
| 1672 | /// |
| 1673 | /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice |
| 1674 | /// guard. |
| 1675 | using ControlFn = std::function<std::optional<bool>(tensor::ExtractSliceOp)>;
| 1676 | |
| 1677 | ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
| 1678 | ControlFn controlFn = nullptr, |
| 1679 | PatternBenefit benefit = 1) |
| 1680 | : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} |
| 1681 | |
| 1682 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
| 1683 | PatternRewriter &rewriter) const override; |
| 1684 | |
| 1685 | private: |
| 1686 | ControlFn controlFn;
| 1687 | }; |
| 1688 | |
| 1689 | //===----------------------------------------------------------------------===// |
| 1690 | // Populate functions. |
| 1691 | //===----------------------------------------------------------------------===// |
| 1692 | |
| 1693 | /// Canonicalization patterns relevant to apply after tiling patterns. These |
| 1694 | /// are applied automatically by the tiling pass but need to be applied |
| 1695 | /// manually when tiling is called programmatically. |
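| | ///
| | /// A sketch of the manual application, assuming a `funcOp` to run on; the
| | /// greedy driver entry point (applyPatternsGreedily, from
| | /// mlir/Transforms/GreedyPatternRewriteDriver.h) is named from memory and has
| | /// been renamed across MLIR versions:
| | /// ```
| | /// RewritePatternSet patterns =
| | ///     getLinalgTilingCanonicalizationPatterns(funcOp.getContext());
| | /// if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
| | ///   // Handle failure, e.g. signalPassFailure() when inside a pass.
| | /// }
| | /// ```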
| 1696 | RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); |
| 1697 | void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); |
| 1698 | |
| 1699 | /// Linalg generalization patterns |
| 1700 | |
| 1701 | /// Populates `patterns` with patterns to convert spec-generated named ops to |
| 1702 | /// linalg.generic ops. |
| 1703 | void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns); |
| 1704 | |
| 1705 | /// Populates `patterns` with patterns to convert linalg.generic ops to named |
| 1706 | /// ops where possible. A linalg.generic can represent a wide range of complex
| 1707 | /// computations for which an equivalent linalg named op may not exist, e.g., a
| 1708 | /// linalg.generic that takes a tensor and computes a polynomial such as:
| 1709 | /// p(x) = an*x^n + ... + a1*x + a0
| 1710 | /// There is no equivalent named op to convert to. Many such cases exist.
| 1711 | void populateLinalgGenericOpsSpecializationPatterns( |
| 1712 | RewritePatternSet &patterns); |
| 1713 | |
| 1714 | /// Populates `patterns` with patterns that fold operations like |
| 1715 | /// `linalg.transpose` into the indexing maps of elementwise ops.
| 1716 | void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns); |
| 1717 | |
| 1718 | /// Linalg decompose convolutions patterns |
| 1719 | |
| 1720 | /// Populates patterns to decompose high-D convolution ops into low-D ones. |
| 1721 | /// This is a step in progressive lowering for convolution ops; afterwards we
| 1722 | /// can vectorize the low-D convolution ops. |
| 1723 | void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, |
| 1724 | PatternBenefit benefit = 1); |
| 1725 | |
| 1726 | /// Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g. |
| 1727 | /// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all |
| 1728 | /// outer dims to be unit. |
| 1729 | void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns); |
| 1730 | |
| 1731 | /// Populates patterns to decompose tensor.pad into e.g. |
| 1732 | /// tensor.empty, linalg.fill, tensor.insert_slice. |
| 1733 | void populateDecomposePadPatterns(RewritePatternSet &patterns); |
| 1734 | |
| 1735 | /// Populates patterns to transform linalg.conv_2d_xxx operations into |
| 1736 | /// linalg.generic (for img2col packing) and linalg.matmul. |
| 1737 | /// \see rewriteInIm2Col for more details. |
| 1738 | void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns); |
| 1739 | |
| 1740 | /// Populates `patterns` with patterns that vectorize tensor.pad. |
| 1741 | /// These patterns are meant to apply in a complementary fashion. Benefits |
| 1742 | /// are used to encode a certain ordering of pattern application. To avoid |
| 1743 | /// scattering magic constants throughout the code base, the patterns must be |
| 1744 | /// added with this function. `baseBenefit` can be used to offset the benefit |
| 1745 | /// of all tensor::PadOp vectorization patterns by a certain value. |
| 1746 | void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, |
| 1747 | PatternBenefit baseBenefit = 1); |
| 1748 | |
| 1749 | /// Populate patterns for splitting a `LinalgOp` with multiple statements within |
| 1750 | /// its payload into multiple `GenericOp` ops that each have a single statement.
| 1751 | /// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
| 1752 | /// and results from the generated decomposed ops. This defaults to `true` since
| 1753 | /// the core decomposition patterns rely on these clean-up patterns. It is set
| 1754 | /// to false only for testing purposes. |
| 1755 | void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, |
| 1756 | bool removeDeadArgsAndResults = true); |
| 1757 | |
| 1758 | /// Populate patterns that convert non-destination-style ops to destination |
| 1759 | /// style ops. |
| 1760 | void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns); |
| 1761 | |
| 1762 | /// Populate patterns for vectorizing low-D convolution ops. This is a step in |
| 1763 | /// progressive lowering for convolution ops; it assumes high-D convolution ops
| 1764 | /// were decomposed previously. |
| 1765 | void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, |
| 1766 | PatternBenefit benefit = 1); |
| 1767 | |
| 1768 | /// Populate patterns that convert `ElementwiseMappable` ops to linalg |
| 1769 | /// parallel loops. |
| 1770 | void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); |
| 1771 | |
| 1772 | /// Populate patterns that are only useful in the context of sparse tensors. |
| 1773 | void populateSparseTensorRewriting(RewritePatternSet &patterns); |
| 1774 | |
| 1775 | /// Function type which is used to control when to stop fusion. It is expected |
| 1776 | /// that OpOperand is not modified in the callback. The OpOperand is not marked |
| 1777 | /// as const to allow callers to use non-const methods. |
| 1778 | using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>; |
| 1779 | |
| 1780 | /// Patterns for fusing linalg operations on tensors.
| 1781 | |
| 1782 | /// Pattern to fuse `linalg.generic` -> `linalg.generic` operations |
| 1783 | /// when both operations are fusable elementwise operations. |
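| | ///
| | /// A sketch of a control function that only fuses single-use producers
| | /// (`patterns` is assumed to be an already constructed RewritePatternSet):
| | /// ```
| | /// ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
| | ///   Operation *producer = fusedOperand->get().getDefiningOp();
| | ///   return producer && producer->hasOneUse();
| | /// };
| | /// populateElementwiseOpsFusionPatterns(patterns, controlFn);
| | /// ```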
| 1784 | void populateElementwiseOpsFusionPatterns( |
| 1785 | RewritePatternSet &patterns, |
| 1786 | const ControlFusionFn &controlElementwiseOpFusion); |
| 1787 | |
| 1788 | /// Function type which is used to control propagation of linalg.pack/unpack |
| 1789 | /// ops. |
| 1790 | using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>; |
| 1791 | |
| 1792 | /// Patterns to bubble up or down data layout ops across other operations. |
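| | ///
| | /// A sketch that allows propagation through every candidate operand
| | /// (`patterns` assumed to exist):
| | /// ```
| | /// ControlPropagationFn control = [](OpOperand *opOperand) { return true; };
| | /// populateDataLayoutPropagationPatterns(patterns, control);
| | /// ```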
| 1793 | void populateDataLayoutPropagationPatterns( |
| 1794 | RewritePatternSet &patterns, |
| 1795 | const ControlPropagationFn &controlPackUnPackPropagation); |
| 1796 | |
| 1797 | /// Pattern to remove dead operands and results of `linalg.generic` operations. |
| 1798 | /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`. |
| 1799 | void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); |
| 1800 | |
| 1801 | /// Patterns to promote inputs to outputs and remove unused inputs of |
| 1802 | /// `linalg.generic` ops. |
| 1803 | void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns); |
| 1804 | |
| 1805 | /// Function type to control generic op dimension collapsing. It is expected |
| 1806 | /// to return an array of `ReassociationIndices` representing dimensions that |
| 1807 | /// should be merged. |
| 1808 | using GetCollapsableDimensionsFn = |
| 1809 | std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>; |
| 1810 | |
| 1811 | /// Pattern to collapse dimensions in a linalg.generic op. This will collapse |
| 1812 | /// tensor operands when needed and expand back the result tensors. |
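| | ///
| | /// A sketch of a control function that always merges the two leading
| | /// dimensions (illustrative only; real callers should inspect the op):
| | /// ```
| | /// GetCollapsableDimensionsFn control = [](LinalgOp op) {
| | ///   SmallVector<ReassociationIndices> collapse;
| | ///   collapse.push_back(ReassociationIndices{0, 1});
| | ///   return collapse;
| | /// };
| | /// populateCollapseDimensions(patterns, control);
| | /// ```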
| 1813 | void populateCollapseDimensions( |
| 1814 | RewritePatternSet &patterns, |
| 1815 | const GetCollapsableDimensionsFn &controlCollapseDimensions); |
| 1816 | |
| 1817 | /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its |
| 1818 | /// producer (consumer) generic operation by expanding the dimensionality of the |
| 1819 | /// loop in the generic op. |
| 1820 | void populateFoldReshapeOpsByExpansionPatterns( |
| 1821 | RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); |
| 1822 | |
| 1823 | /// Patterns to fold an expanding tensor.expand_shape operation with its |
| 1824 | /// producer generic operation by collapsing the dimensions of the generic op. |
| 1825 | void populateFoldReshapeOpsByCollapsingPatterns( |
| 1826 | RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); |
| 1827 | |
| 1828 | /// Patterns to constant fold Linalg operations. |
| 1829 | void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, |
| 1830 | const ControlFusionFn &controlFn); |
| 1831 | |
| 1832 | /// Pattern to replace `linalg.add` when destination passing on a contraction op |
| 1833 | /// suffices for achieving the sum. |
| 1834 | void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); |
| 1835 | |
| 1836 | /// Pattern to fuse a `tensor.pad` operation with the producer of its source, |
| 1837 | /// if the producer is a `linalg` operation with all parallel iterator types. |
| 1838 | void populateFuseTensorPadWithProducerLinalgOpPatterns( |
| 1839 | RewritePatternSet &patterns); |
| 1840 | |
| 1841 | /// Patterns to convert from one named op to another. These can be seen as |
| 1842 | /// canonicalizations of named ops into another named op. |
| 1843 | void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); |
| 1844 | |
| 1845 | /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on |
| 1846 | /// tensors via reassociative reshape ops. |
| 1847 | void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, |
| 1848 | ControlDropUnitDims &options); |
| 1849 | |
| 1850 | /// A pattern that converts init operands to input operands. |
| 1851 | void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns); |
| 1852 | |
| 1853 | /// Patterns that are used to inline constant operands into linalg generic ops. |
| 1854 | void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns); |
| 1855 | |
| 1856 | /// Patterns that are used to bubble up extract slice op above linalg op. |
| 1857 | void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
| 1858 | |
| 1859 | /// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
| 1860 | /// linalg.fill(%cst, tensor.extract_slice(%init)).
| 1861 | void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
| 1862 | |
| 1863 | /// Add patterns to make explicit broadcasts and transforms in the |
| 1864 | /// input operands of a genericOp. |
| 1865 | void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns); |
| 1866 | |
| 1867 | /// Patterns to apply `splitReduction` below. |
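| | ///
| | /// A sketch of a control function that splits every reduction by a factor of
| | /// 4 at reduction index 0; the `SplitReductionOptions` fields are the ones
| | /// declared earlier in this header (`patterns` assumed to exist):
| | /// ```
| | /// ControlSplitReductionFn control = [](LinalgOp op) {
| | ///   return SplitReductionOptions{/*ratio=*/4, /*index=*/0,
| | ///                                /*innerParallel=*/false};
| | /// };
| | /// populateSplitReductionPattern(patterns, control);
| | /// ```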
| 1868 | void populateSplitReductionPattern( |
| 1869 | RewritePatternSet &patterns, |
| 1870 | const ControlSplitReductionFn &controlSplitReductionFn, |
| 1871 | bool useAlloc = false); |
| 1872 | |
| 1873 | /// Patterns to convert Linalg matmul ops to transposed variants. |
| 1874 | void populateTransposeMatmulPatterns(RewritePatternSet &patterns, |
| 1875 | bool transposeLHS = true); |
| 1876 | |
| 1877 | /// Patterns to block pack Linalg matmul ops. |
| 1878 | void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, |
| 1879 | const ControlBlockPackMatmulFn &controlFn); |
| 1880 | |
| 1881 | /// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). |
| 1882 | void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, |
| 1883 | int64_t r); |
| 1884 | |
| 1885 | /// Patterns to decompose Winograd operators. |
| 1886 | void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); |
| 1887 | |
| 1888 | /// Adds patterns that reduce the rank of named contraction ops that have |
| 1889 | /// unit dimensions in the operand(s) by converting to a sequence of |
| 1890 | /// `collapse_shape`, |
| 1891 | /// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For |
| 1892 | /// example a `linalg.batch_matmul` with unit batch size will convert to |
| 1893 | /// `linalg.matmul` and a `linalg.matvec` with unit spatial dim in lhs will
| 1894 | /// convert to a `linalg.dot`. |
| 1895 | void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); |
| 1896 | |
| 1897 | /// Populates `patterns` with patterns that fold operations like `tensor.pad` |
| 1898 | /// and `tensor.extract_slice` into `linalg.pack` and `linalg.unpack` operations
| 1899 | /// respectively. |
| 1900 | void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns); |
| 1901 | |
| 1902 | /// Populates `patterns` with patterns that fold operations like `linalg.pack` |
| 1903 | /// and `linalg.unpack` into `tensor.empty`. |
| 1904 | void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns); |
| 1905 | |
| 1906 | /// Populates `patterns` with patterns that simplify `linalg.pack` and
| 1907 | /// `linalg.unpack` operations.
| 1908 | void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns); |
| 1909 | |
| 1910 | } // namespace linalg |
| 1911 | } // namespace mlir |
| 1912 | |
| 1913 | #endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H |
| 1914 | |