1//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
10#define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
11
12#include <utility>
13
14#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Linalg/Utils/Utils.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22#include "mlir/Dialect/X86Vector/Transforms.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Interfaces/TilingInterface.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallSet.h"
28
29namespace mlir {
30namespace bufferization {
31class AllocTensorOp;
32class OneShotAnalysisState;
33class BufferizationState;
34} // namespace bufferization
35
36namespace linalg {
37
38class LinalgOp;
39
40//===----------------------------------------------------------------------===//
41// Utils.
42//===----------------------------------------------------------------------===//
43
44/// Return vector::CombiningKind for the given op.
45std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
46
47//===----------------------------------------------------------------------===//
48// Bufferization-related transforms.
49//===----------------------------------------------------------------------===//
50
51struct BufferizeToAllocationOptions {
52 enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
53 AllocOp allocOp = AllocOp::MemrefAlloc;
54
55 enum class MemcpyOp {
56 MaterializeInDestination = 0,
57 MemrefCopy = 1,
58 LinalgCopy = 2
59 };
60 MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination;
61
62 /// If set to "true", only the destination tensor operands are bufferized to
63 /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
64 /// targeted op itself.
65 bool bufferizeDestinationOnly = false;
66
67 /// If set to "true", a `memref.dealloc` operation will be emitted for each
68 /// allocated buffer. Otherwise, the memory is leaked, which is useful if
69 /// the buffer deallocation pipeline should be run after bufferization is
70 /// done.
71 bool emitDealloc = false;
72};
73
74/// Materialize a buffer allocation for the given tensor.pad op and lower the
75/// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
76/// E.g.:
77///
78/// %0 = tensor.pad low[%l] high[%h] %t ...
79///
80/// is lowered to:
81///
82/// %alloc = memref.alloc
83/// linalg.fill ... outs(%alloc)
84/// %subview = memref.subview %alloc [%l] [...] [1]
85/// bufferization.materialize_in_destination %t in %subview
86/// %0 = bufferization.to_tensor %alloc restrict writable
87///
88/// In addition to rewriting the IR as shown above, this function returns the
89/// newly allocated buffer. The `insertionPoint` parameter can be used to
90/// specify a custom insertion point for the buffer allocation.
91Value bufferizeToAllocation(RewriterBase &rewriter,
92 const BufferizeToAllocationOptions &options,
93 tensor::PadOp padOp, Attribute memorySpace = {},
94 Operation *insertionPoint = nullptr);
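
// Example usage: a minimal sketch. `rewriter` and `padOp` are assumed to be
// provided by an enclosing pattern or pass; the option values are
// illustrative.
//
//   BufferizeToAllocationOptions options;
//   options.emitDealloc = true;  // Also emit a matching memref.dealloc.
//   Value buffer = bufferizeToAllocation(rewriter, options, padOp);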
95
96/// Materialize a buffer allocation for the given vector.mask op and bufferize
97/// the op, including its region. E.g.:
98///
99/// %0 = vector.mask {
100/// vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
101/// } : vector<16xi1> -> tensor<?xf32>
102///
103/// is lowered to:
104///
105/// %alloc = memref.alloc
///   bufferization.materialize_in_destination %t in %alloc
107/// vector.mask {
108/// vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
109/// } : vector<16xi1>
110/// %0 = bufferization.to_tensor %alloc restrict writable
111///
112/// In addition to rewriting the IR as shown above, this function returns the
113/// newly allocated buffer. The `insertionPoint` parameter can be used to
114/// specify a custom insertion point for the buffer allocation.
115Value bufferizeToAllocation(RewriterBase &rewriter,
116 const BufferizeToAllocationOptions &options,
117 vector::MaskOp maskOp, Attribute memorySpace = {},
118 Operation *insertionPoint = nullptr);
119
120/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
121/// and lower the op to memref.alloc + memref.tensor_store.
122///
123/// In addition to rewriting the IR, this function returns the newly allocated
124/// buffer. The `insertionPoint` parameter can be used to specify a custom
125/// insertion point for the buffer allocation.
126Value bufferizeToAllocation(RewriterBase &rewriter,
127 const BufferizeToAllocationOptions &options,
128 bufferization::AllocTensorOp allocTensorOp,
129 Attribute memorySpace = {},
130 Operation *insertionPoint = nullptr);
131
132/// Bufferize the given op with tensor semantics and materialize the result in
133/// a newly allocated buffer.
134///
135/// Only bufferizable ops that bufferize to a memory write or have an
136/// aliasing OpOperand (and do not themselves bufferize to an allocation) are
137/// supported. They are bufferized using their BufferizableOpInterface
138/// implementation.
139///
140/// Selected ops that bufferize to an allocation (or need special handling) are
141/// also supported:
142/// - tensor.pad
143/// - vector.mask
144///
145/// This function returns the newly allocated buffer. The `insertionPoint`
146/// parameter can be used to specify a custom insertion point for the buffer
147/// allocation.
148Value bufferizeToAllocation(RewriterBase &rewriter,
149 const BufferizeToAllocationOptions &options,
150 Operation *op, Attribute memorySpace = {},
151 Operation *insertionPoint = nullptr);
152
153/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
/// LinalgOp. This transform looks for LinalgOps that have an unused output
155/// operand and an input operand that is rooted in a tensor::EmptyOp. The
156/// tensor::EmptyOp uses are replaced with the output operand and the two
157/// operands of the LinalgOp are swapped.
158///
159/// Example:
160/// %0 = tensor.empty()
161/// %1 = linalg.matmul ins(...) outs(%0)
162/// %2 = linalg.generic ins(%1) outs(%dest) {
163/// ^bb0(%in: f32, %out: f32):
164/// // out not used
165/// }
166///
167/// The IR is transformed as follows:
168/// %0 = tensor.empty()
169/// %1 = linalg.matmul ins(...) outs(%dest)
170/// %2 = linalg.generic ins(%0) outs(%1) {
171/// ^bb0(%in: f32, %out: f32):
172/// // Use %out instead of %in
173/// }
174///
175/// The "ins" operand has no uses inside the body of the LinalgOp and can be
176/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp
177/// can also fold away.
178LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(
179 RewriterBase &rewriter, Operation *op,
180 bufferization::OneShotAnalysisState &state);
181
182//===----------------------------------------------------------------------===//
183// Structs that configure the behavior of various transformations.
184//===----------------------------------------------------------------------===//
185
186using TileSizeComputationFunction =
187 std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
188
189struct LinalgTilingOptions {
190 /// Computation function that returns the tile sizes for each operation.
191 /// Delayed construction of constant tile sizes should occur to interoperate
192 /// with folding.
193 TileSizeComputationFunction tileSizeComputationFunction = nullptr;
194
195 LinalgTilingOptions &
196 setTileSizeComputationFunction(TileSizeComputationFunction fun) {
197 tileSizeComputationFunction = std::move(fun);
198 return *this;
199 }
200 /// Set the `tileSizeComputationFunction` to return the values `ts`. The
201 /// values must not fold away when tiling. Otherwise, use a more robust
202 /// `tileSizeComputationFunction`.
203 LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
204 tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
205 return *this;
206 }
207 /// Convenience function to set the `tileSizeComputationFunction` to a
208 /// function that computes tile sizes at the point they are needed. Allows
209 /// proper interaction with folding.
210 LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
211
212 /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
213 /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
214 LinalgTilingOptions &scalarizeDynamicDims();
215
216 /// The interchange vector to reorder the tiled loops.
217 SmallVector<unsigned, 4> interchangeVector = {};
218
219 LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
    interchangeVector.assign(interchange.begin(), interchange.end());
221 return *this;
222 }
223
224 /// The type of tile loops to generate.
225 LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;
226
227 LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
228 loopType = lt;
229 return *this;
230 }
231
  /// When specified, describes how the generated tile loops should be
  /// distributed to processors.
234 std::optional<LinalgLoopDistributionOptions> distribution;
235
236 LinalgTilingOptions &
237 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
238 distribution = std::move(distributionOptions);
239 return *this;
240 }
241
242 /// Specification markers of how to distribute the `linalg.tiled_loop`.
243 SmallVector<StringRef, 2> distributionTypes = {};
244
245 LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
    distributionTypes.assign(types.begin(), types.end());
247 return *this;
248 }
249
250 /// Peel the specified loops.
251 SmallVector<int64_t> peeledLoops;
252
253 LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
254 peeledLoops.clear();
    peeledLoops.append(loops.begin(), loops.end());
256 return *this;
257 }
258};
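
// Example usage: a configuration sketch with illustrative values; the
// resulting options are typically passed to `tileLinalgOp` (declared further
// below). The tile sizes and interchange shown are assumptions, not defaults.
//
//   LinalgTilingOptions tilingOptions;
//   tilingOptions.setTileSizes({32, 32, 64})
//       .setInterchange({1, 0, 2})
//       .setLoopType(LinalgTilingLoopType::Loops)
//       .setPeeledLoops({0});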
259
260struct LinalgTilingAndFusionOptions {
261 /// Tile sizes used to tile the root operation.
262 SmallVector<int64_t> tileSizes;
263 LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
    tileSizes.assign(ts.begin(), ts.end());
265 return *this;
266 }
267 /// Tile interchange used to permute the tile loops.
268 SmallVector<int64_t> tileInterchange;
  /// When specified, describes how the generated tile loops should be
  /// distributed to processors.
271 std::optional<LinalgLoopDistributionOptions> tileDistribution;
272 LinalgTilingAndFusionOptions &
273 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
274 tileDistribution = std::move(distributionOptions);
275 return *this;
276 }
277};
278
279struct LinalgPaddingOptions {
280 /// A padding value for every operand.
281 SmallVector<Attribute> paddingValues;
282 LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
    paddingValues.assign(pv.begin(), pv.end());
284 return *this;
285 }
286 /// A list of iterator dimensions to pad.
287 SmallVector<int64_t> paddingDimensions;
288 LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
    paddingDimensions.assign(pd.begin(), pd.end());
290 return *this;
291 }
  /// A list of multiples that each padding dimension should be padded to.
293 std::optional<SmallVector<int64_t>> padToMultipleOf;
294 LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) {
    padToMultipleOf.emplace(m.begin(), m.end());
296 return *this;
297 }
298 /// A flag for every operand to mark the PadOp as nofold which enables
299 /// packing for statically shaped operands.
300 SmallVector<bool> nofoldFlags;
301 LinalgPaddingOptions &setNofoldFlags(ArrayRef<bool> pp) {
    nofoldFlags.assign(pp.begin(), pp.end());
303 return *this;
304 }
305 /// A number of loops to hoist the PadOp out for every operand.
306 SmallVector<int64_t> hoistPaddings;
307 LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
    hoistPaddings.assign(hp.begin(), hp.end());
309 return *this;
310 }
311 /// A permutation vector for every operand used to transpose the packed
312 /// PadOp results.
313 SmallVector<SmallVector<int64_t>> transposePaddings;
314 LinalgPaddingOptions &
315 setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
    transposePaddings.assign(tp.begin(), tp.end());
317 return *this;
318 }
319 enum class CopyBackOp : int8_t {
320 None = 0,
321 BufferizationMaterializeInDestination = 1,
322 LinalgCopy = 2
323 };
324 /// The op to be used for copying the padded result to the original
325 /// destination tensor.
326 CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination;
327 LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
328 copyBackOp = op;
329 return *this;
330 }
331};
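
// Example usage: a padding configuration sketch for a matmul-like op with
// three f32 operands. `rewriter` is assumed to be in scope; the dimensions,
// multiples, and flags are illustrative assumptions.
//
//   LinalgPaddingOptions paddingOptions;
//   paddingOptions.setPaddingDimensions({0, 1})
//       .setPadToMultipleOf({16, 16})
//       .setPaddingValues({rewriter.getF32FloatAttr(0.0f),
//                          rewriter.getF32FloatAttr(0.0f),
//                          rewriter.getF32FloatAttr(0.0f)})
//       .setNofoldFlags({true, true, true})
//       .setCopyBackOp(LinalgPaddingOptions::CopyBackOp::
//                          BufferizationMaterializeInDestination);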
332
333/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use.
338using AllocBufferCallbackFn = std::function<std::optional<Value>(
339 OpBuilder &b, memref::SubViewOp subView,
340 ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;
341
342/// Callback function type used to deallocate the buffers used to hold the
343/// promoted subview.
344using DeallocBufferCallbackFn =
345 std::function<LogicalResult(OpBuilder &b, Value buffer)>;
346
347/// Callback function type used to insert copy from original subview to
348/// subview of the promoted region for the read operands/subview of promoted
349/// region to original subview for the results. The copy has to happen from
350/// `src` to `dst`.
351using CopyCallbackFn =
352 std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
353
354struct LinalgPromotionOptions {
355 /// Indices of subViews to promote. If `std::nullopt`, try to promote all
356 /// operands.
357 std::optional<DenseSet<unsigned>> operandsToPromote;
358 LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
359 operandsToPromote = DenseSet<unsigned>();
    operandsToPromote->insert_range(operands);
361 return *this;
362 }
  /// If the ith element of `useFullTiles` is true, the full view should be
  /// used for the promoted buffer of the ith operand in `operandsToPromote`.
365 /// Otherwise the partial view will be used. The decision is defaulted to
366 /// `useFullTileBuffersDefault` when `useFullTileBuffers` is std::nullopt and
367 /// for operands missing from `useFullTileBuffers`.
368 std::optional<llvm::SmallBitVector> useFullTileBuffers;
369 LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
370 unsigned size = useFullTiles.size();
371 llvm::SmallBitVector tmp(size, false);
372 for (unsigned i = 0; i < size; ++i)
373 tmp[i] = useFullTiles[i];
374 useFullTileBuffers = tmp;
375 return *this;
376 }
377 /// If true all operands unspecified by `useFullTileBuffers` will use the
378 /// full view, otherwise the partial view.
379 bool useFullTileBuffersDefault = false;
380 LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
381 useFullTileBuffersDefault = use;
382 return *this;
383 }
384 /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment.
385 std::optional<unsigned> alignment;
386 LinalgPromotionOptions &setAlignment(unsigned align) {
387 alignment = align;
388 return *this;
389 }
390 /// Memory space of promoted buffer. If `std::nullopt` do not specify memory
391 /// space.
392 std::optional<Attribute> memorySpace;
393 LinalgPromotionOptions &setMemorySpace(Attribute memorySpc) {
394 memorySpace = memorySpc;
395 return *this;
396 }
397 /// Use alloca with the default allocation scheme.
398 bool useAlloca = false;
399 LinalgPromotionOptions &setUseAlloca(bool use) {
400 useAlloca = use;
401 return *this;
402 }
403 /// Callback function to do the allocation of the promoted buffer. If
404 /// std::nullopt, then the default allocation scheme of allocating a
405 /// memref<?xi8> buffer followed by a view operation is used.
406 std::optional<AllocBufferCallbackFn> allocationFn;
407 std::optional<DeallocBufferCallbackFn> deallocationFn;
408 LinalgPromotionOptions &
409 setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
410 DeallocBufferCallbackFn const &deallocFn) {
411 allocationFn = allocFn;
412 deallocationFn = deallocFn;
413 return *this;
414 }
415 /// Callback function to do the copy of data to and from the promoted
416 /// subview. If std::nullopt then a memref.copy is used.
417 std::optional<CopyCallbackFn> copyInFn;
418 std::optional<CopyCallbackFn> copyOutFn;
419 LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
420 CopyCallbackFn const &copyOut) {
421 copyInFn = copyIn;
422 copyOutFn = copyOut;
423 return *this;
424 }
425};
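
// Example usage: a promotion sketch. `builder` and `linalgOp` are assumed to
// be provided by the caller; here the first two operands are promoted into
// 16-byte aligned buffers and then `promoteSubViews` (declared further below)
// is invoked.
//
//   LinalgPromotionOptions promotionOptions;
//   promotionOptions.setOperandsToPromote({0, 1})
//       .setUseFullTileBuffers({true, true})
//       .setAlignment(16);
//   FailureOr<LinalgOp> promoted =
//       promoteSubViews(builder, linalgOp, promotionOptions);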
426
427/// Split Reduction options.
428struct SplitReductionOptions {
429 // Ratio used to split the reduction dimension. If the ratio is <= 1,
430 // nothing will be done.
431 int64_t ratio = 0;
432 // Index where the extra dimension is added to the intermediate tensor
433 // shape.
434 unsigned index = 0;
435 // If the inner dimension after splitting is parallel or reduction.
436 bool innerParallel = false;
437};
438
439/// Function signature to control reduction splitting. This returns
440/// `SplitReductionOptions`.
441// TODO: don't use unsigned unless doing bit manipulation.
442using ControlSplitReductionFn =
443 std::function<SplitReductionOptions(LinalgOp op)>;
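
// Example usage: a control-function sketch that always splits the reduction
// dimension by a factor of 4 at iterator position 0. The policy is purely
// illustrative.
//
//   ControlSplitReductionFn control = [](LinalgOp op) {
//     SplitReductionOptions options;
//     options.ratio = 4;           // Split the reduction dimension by 4.
//     options.index = 0;           // Insert the new dimension at position 0.
//     options.innerParallel = false;
//     return options;
//   };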
444
445//===----------------------------------------------------------------------===//
446// Preconditions that ensure the corresponding transformation succeeds and can
447// be applied as a rewrite pattern.
448//===----------------------------------------------------------------------===//
449
450/// Return true if two `linalg.generic` operations with producer/consumer
451/// relationship through `fusedOperand` can be fused using elementwise op
452/// fusion.
453bool areElementwiseOpsFusable(OpOperand *fusedOperand);
454
455/// Promote memref.subviews feeding linalg-on-buffers operations.
456LogicalResult promoteSubviewsPrecondition(Operation *op,
457 LinalgPromotionOptions options);
458
459/// Return success if the operation can be vectorized.
460LogicalResult vectorizeOpPrecondition(Operation *op,
461 ArrayRef<int64_t> inputVectorSizes = {},
462 ArrayRef<bool> inputScalableVecDims = {},
463 bool vectorizeNDExtract = false,
464 bool flatten1DDepthwiseConv = false);
465
466//===----------------------------------------------------------------------===//
467// Transformations exposed as functional-style API calls.
468//===----------------------------------------------------------------------===//
469
470using LinalgLoops = SmallVector<Operation *, 4>;
471
472/// Transformation to drop unit-extent dimensions from `linalg.generic`
473/// operations.
474struct ControlDropUnitDims {
475 enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
476
477 RankReductionStrategy rankReductionStrategy =
478 RankReductionStrategy::ReassociativeReshape;
479
480 using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
481 ControlFnTy controlFn = [](Operation *op) {
482 if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
483 return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
484 }
485 if (auto padOp = dyn_cast_or_null<tensor::PadOp>(op)) {
486 return llvm::to_vector(
487 llvm::seq<unsigned>(0, padOp.getSourceType().getRank()));
488 }
489 return SmallVector<unsigned>{};
490 };
491};
492struct DropUnitDimsResult {
493 linalg::GenericOp resultOp;
494 SmallVector<Value> replacements;
495};
496FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
497 GenericOp genericOp,
498 const ControlDropUnitDims &options);
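
// Example usage: a sketch that drops unit dims via extract/insert_slice
// rather than reshapes. `rewriter` and `genericOp` are assumed to come from
// an enclosing pattern.
//
//   ControlDropUnitDims options;
//   options.rankReductionStrategy =
//       ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
//   FailureOr<DropUnitDimsResult> result =
//       dropUnitDims(rewriter, genericOp, options);
//   if (succeeded(result))
//     rewriter.replaceOp(genericOp, result->replacements);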
499
500/// Fuse two `linalg.generic` operations that have a producer-consumer
501/// relationship captured through `fusedOperand`. The method expects
502/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
503struct ElementwiseOpFusionResult {
504 Operation *fusedOp;
505 llvm::DenseMap<Value, Value> replacements;
506};
507FailureOr<ElementwiseOpFusionResult>
508fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
509
510/// Returns a set of indices of the producer's results which would
511/// be preserved after the fusion.
/// * The implementation of the transformation may not agree with the result
///   of this method; this function gives a prediction based on an optimized
///   fusion.
515llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
516 GenericOp consumer,
517 OpOperand *fusedOperand);
518
519/// Try to peel and canonicalize loop `op` and return the new result.
520/// Also applies affine_min/max bounds simplification on the fly where relevant.
521// TODO: Add support for scf.parallel and affine.for loops.
522SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
523
/// Peel `loops` and apply affine_min/max bounds simplification on the fly
525/// where relevant.
526void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
527
528/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
529/// to a static bounding box. The original `opToPad` is cloned and operates on
530/// the padded tensors.
531///
532/// * "options.padToMultipleOf" indicates that each padding dimension should be
533/// padded to the specified multiple.
534/// * Use "options.paddingValues" and "options.nofoldFlags" to set padding
535/// value and nofold attribute of the created tensor::PadOps, respectively.
536/// * The unpadded results (extracted slice of the cloned operation) are
537/// returned via `replacements`.
538/// * The tensor::PadOps are returned via `padOps`.
539/// * "options.copyBackOp" specifies the op type for copying back the unpadded
540/// result to the original destination tensor.
541LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
542 const LinalgPaddingOptions &options,
543 LinalgOp &paddedOp,
544 SmallVector<Value> &replacements,
545 SmallVector<tensor::PadOp> &padOps);
546
547namespace detail {
548
549/// Helper struct to hold the results of building a packing loop nest.
550struct PackingResult {
551 SmallVector<OpFoldResult> offsets, sizes, strides;
552 SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
553 TransposeOp maybeTransposeOp;
554 tensor::PadOp hoistedPadOp;
555};
556
557/// Build the packing loop nest required to hoist `opToHoist` above
558/// `outermostEnclosingForOp`.
559/// The loop nest is built just before `outermostEnclosingForOp`.
560FailureOr<PackingResult>
561buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist,
562 scf::ForOp outermostEnclosingForOp,
563 ArrayRef<int64_t> transposeVector);
564
565} // namespace detail
566
567/// Mechanically hoist padding operations on tensors by `numLoops` into a new,
568/// generally larger tensor. This achieves packing of multiple padding ops into
569/// a larger tensor. On success, `opToHoist` is replaced by the cloned version
570/// in the packing loop so the caller can continue reasoning about the padding
571/// operation. If `transposeVector` is non-empty, hoist padding introduces a
572/// TransposeOp to transpose the padded tensor before inserting it into the
573/// packed tensor. A `transposeVector` can change the storage order of the
574/// padded tensor but does not change the order of the pack or compute loops.
575///
576/// TODO: In the future, we should consider rewriting as a linalg.pack after
577/// hoisting since this abstraction is now available.
578///
579/// Example in pseudo-mlir:
580/// =======================
581///
/// If hoistPaddingOnTensors is called with `numLoops` = 2 on the following IR:
583/// ```
584/// scf.for (%i, %j, %k)
585/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
586/// %0 = tensor.pad %st0 low[0, 0] high[...] {
587/// ^bb0( ... ):
588/// linalg.yield %pad
589/// } : tensor<?x?xf32> to tensor<4x8xf32>
590/// compute(%0)
591/// ```
592///
593/// IR resembling the following is produced:
594///
595/// ```
596/// scf.for (%i) {
597/// %packed_init = tensor.empty range(%j) : tensor<?x4x8xf32>
598/// %packed = scf.for (%k) iter_args(%p : %packed_init) {
599/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
600/// %0 = tensor.pad %st0 low[0, 0] high[...] {
601/// ^bb0( ... ):
602/// linalg.yield %pad
603/// } : tensor<?x?xf32> to tensor<4x8xf32>
604/// %1 = tensor.insert_slice %0 ...
605/// : tensor<4x8xf32> to tensor<?x4x8xf32>
606/// scf.yield %1: tensor<?x4x8xf32>
607/// } -> tensor<?x4x8xf32>
608/// scf.for (%j, %k) {
609/// %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
610/// tensor<?x4x8xf32> to tensor<4x8xf32>
611/// compute(%st0)
612/// }
613/// }
614/// ```
615FailureOr<Value>
616hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist,
617 int64_t numLoops, ArrayRef<int64_t> transposeVector,
618 tensor::PadOp &hoistedOp,
619 SmallVectorImpl<TransposeOp> &transposeOps);
620/// Calls into `hoistPaddingOnTensors` with a local IRRewriter.
621FailureOr<Value>
622hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
623 ArrayRef<int64_t> transposeVector,
624 tensor::PadOp &hoistedOp,
625 SmallVectorImpl<TransposeOp> &transposeOps);
626
627/// Apply padding and hoisting to `linalgOp` according to the configuration
628/// specified in `options`.
629FailureOr<LinalgOp> padAndHoistLinalgOp(RewriterBase &rewriter,
630 LinalgOp linalgOp,
631 const LinalgPaddingOptions &options);
632
633/// Split the given `op` into two parts along the given iteration space
634/// `dimension` at the specified `splitPoint`, and return the two parts.
635/// If the second part is statically known to be empty, do not create it
636/// and return nullptr instead. Error state is signalled by returning
637/// a pair of nullptrs.
638///
639/// For example, the following op:
640///
641/// linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
642/// outs(%2 : tensor<128x64xf32>)
643///
644/// split along the first dimension at position 42 will result in:
645///
646/// %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
647/// %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
648/// %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
///                      outs(%4 : tensor<42x64xf32>)
650/// %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
651///
652/// %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
653/// %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
654/// %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
655/// outs(%8 : tensor<86x64xf32>)
///   tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
657///
658/// Note that there is no simplification other than constant propagation applied
659/// to slice extraction and insertion.
660std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter,
661 TilingInterface op,
662 unsigned dimension,
663 OpFoldResult splitPoint);
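
// Example usage: a sketch that splits the first iteration-space dimension at
// a static point. `rewriter` and `tilingInterfaceOp` are assumed, and the
// enclosing code is assumed to return a LogicalResult.
//
//   OpFoldResult splitPoint = rewriter.getIndexAttr(42);
//   auto [first, second] =
//       splitOp(rewriter, tilingInterfaceOp, /*dimension=*/0, splitPoint);
//   if (!first && !second)
//     return failure();  // Error state: both parts are null.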
664
/// Perform standalone tiling of a single LinalgOp by `tileSizes` and permute
/// the loop nest according to `interchangeVector`.
667/// The permutation is expressed as a list of integers that specify
668/// the new ordering of the loop nest. The length of `interchangeVector`
669/// must be equal to the length of `tileSizes`.
670/// An empty vector is interpreted as the identity permutation and the
671/// transformation returns early.
672///
673/// Return a struct containing the tiled loops in the specified order
674/// and the cloned op if successful, std::nullopt otherwise.
675///
676/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
677/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
678/// integers, in the range 0..`tileSizes.size()` without duplications
679/// (i.e. `[1,1,2]` is an invalid permutation).
680struct TiledLinalgOp {
681 LinalgOp op;
682 SmallVector<Operation *, 8> loops;
683 SmallVector<Value, 4> tensorResults;
684};
685FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
686 const LinalgTilingOptions &options);
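
// Example usage: a tiling sketch. `rewriter` and a tensor-semantics
// `linalgOp` are assumed; the tile sizes are illustrative and 0 means "do not
// tile this loop".
//
//   LinalgTilingOptions tilingOptions;
//   tilingOptions.setTileSizes({8, 16, 0});
//   FailureOr<TiledLinalgOp> tiled =
//       tileLinalgOp(rewriter, linalgOp, tilingOptions);
//   if (failed(tiled))
//     return failure();
//   rewriter.replaceOp(linalgOp, tiled->tensorResults);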
687
/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled
690/// by `interchangeVector`. An empty vector is interpreted as the identity
691/// permutation and the transformation returns early.
692///
693/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
694/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
695/// integers, in the range 0..`op.rank` without duplications
696/// (i.e. `[1,1,2]` is an invalid permutation).
697///
698/// Return failure if the permutation is not valid.
699FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
700 GenericOp genericOp,
701 ArrayRef<unsigned> interchangeVector);
702
703/// Create a GenericOp from the given named operation `linalgOp` and replace
704/// the given `linalgOp`.
705/// Return failure if `linalgOp` is a GenericOp or misses a region builder.
706FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
707 LinalgOp linalgOp);
708
709/// Create a namedOp from the given GenericOp and replace the GenericOp.
710/// Currently we can specialize only trivial linalg copy operations.
711FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
712 GenericOp genericOp);
713
714/// Create a new buffer using the `allocationFn` provided. The size of this
715/// buffer is the smallest constant bounding size along each dimension that
716/// can be computed for the size of the result of `subView`. Returns the
717/// allocated buffer as `fullLocalView` and the view that matches the size of
718/// the result of subview operation as `partialLocalView`.
719struct PromotionInfo {
720 Value fullLocalView;
721 Value partialLocalView;
722};
723FailureOr<PromotionInfo>
724promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
725 const AllocBufferCallbackFn &allocationFn,
726 DataLayout &layout);
727
728/// Promote the `subViews` into a new buffer allocated at the insertion point
729/// `b`. Promotion occurs in 3 steps:
730/// 1. Create a new buffer for a full tile (i.e. not clipped at the
731/// boundary).
732/// 2. Take a full view on the buffer.
733/// 3. Take a partial slice of the full view in step 2. and copy into it.
734///
735/// Return the modified linalg op (the modification happens in place) as well
736/// as all the copy ops created.
737FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
738 const LinalgPromotionOptions &options);
739
740/// Allocate the subview in the GPU workgroup memory.
741std::optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
742 memref::SubViewOp subview,
743 ArrayRef<Value> sizeBounds,
744 DataLayout &);
745
746/// In case of GPU group memory there is no need to deallocate.
747LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/);
748
749/// Create Memref copy operations and add gpu barrier guards before and after
750/// the copy operation to ensure data integrity.
751LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst);
752
753/// Allocate the subview in the GPU private memory.
754std::optional<Value> allocateGPUPrivateMemory(OpBuilder &builder,
755 memref::SubViewOp subview,
756 ArrayRef<Value> sizeBounds,
757 DataLayout &);
758
/// Normal copy between src and dst.
760LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
761
762/// In case of GPU private memory there is no need to deallocate since the
763/// memory is freed when going outside of the scope.
764LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
765
766/// Return true if there's dedicated logic in the Linalg Vectorizer to
767/// vectorize this Op, false otherwise.
768///
769/// Note that this helper merely implements a very high level check and that the
770/// vectorizer also requires various additional pre-conditions to be met for it
771/// to work (these are checked by the vectorizer itself).
772bool hasVectorizationImpl(Operation *);
773
774/// Emit a suitable vector form for an operation. If provided,
775/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
776/// must match the rank of the iteration space of the operation and the sizes
/// must be smaller than or equal to their counterpart iteration space sizes,
/// if static. `inputVectorSizes` also allows the vectorization of operations
/// with dynamic shapes.
780LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
781 ArrayRef<int64_t> inputVectorSizes = {},
782 ArrayRef<bool> inputScalableVecDims = {},
783 bool vectorizeNDExtract = false,
784 bool flatten1DDepthwiseConv = false);
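
// Example usage: a vectorization sketch. `rewriter` and `candidateOp` are
// assumed; with static shapes no explicit vector sizes are needed, otherwise
// pass `inputVectorSizes` (and `inputScalableVecDims`) explicitly.
//
//   if (!hasVectorizationImpl(candidateOp) ||
//       failed(vectorizeOpPrecondition(candidateOp)))
//     return failure();
//   if (failed(vectorize(rewriter, candidateOp)))
//     return failure();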
785
786/// Emit a suitable vector form for a Copy op with fully static shape.
787LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
788
789/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
790FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
791 LinalgOp linalgOp);
792
793/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
794FailureOr<LinalgLoops> linalgOpToParallelLoops(RewriterBase &rewriter,
795 LinalgOp linalgOp);
796
797/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
798FailureOr<LinalgLoops> linalgOpToAffineLoops(RewriterBase &rewriter,
799 LinalgOp linalgOp);
800
/// Creates a number of ranges equal to the number of non-zero entries in
/// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
/// `tileSizes` argument
803/// has one entry per surrounding loop. It uses zero as the convention that a
804/// particular loop is not tiled. This convention simplifies implementations
805/// by avoiding affine map manipulations. The returned ranges correspond to
806/// the loop ranges, in the proper order, that are tiled and for which new
807/// loops will be created. Also the function returns a map from loop indices
808/// of the LinalgOp to the corresponding non-empty range indices of newly
809/// created loops.
810using LoopIndexToRangeIndexMap = DenseMap<int, int>;
811std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
812makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
813 ArrayRef<OpFoldResult> allShapeSizes,
814 ArrayRef<OpFoldResult> allTileSizes);
815
816namespace detail {
817template <typename T>
818struct MultiSizeSpecificationBase {
819 /// Tile sizes.
820 T lowTileSize, highTileSize;
821 /// Number of tiles associated with each size.
822 T lowTripCount, highTripCount;
823};
824
825template <typename T>
826struct ContinuousTileSizeSpecificationBase {
827 /// Tile sizes.
828 SmallVector<T> tileSizes;
829 /// Number of tiles associated with each size.
830 SmallVector<T> tripCounts;
831};
832
833} // namespace detail
834
835/// A description of a multi-size tiling comprising tile sizes and numbers of
836/// tiles, expressed as Values which may or may not be constant. Multi-size
837/// currently means two-size.
838struct MultiSizeSpecification
839 : public detail::MultiSizeSpecificationBase<Value> {};
840struct StaticMultiSizeSpecification
841 : public detail::MultiSizeSpecificationBase<int64_t> {};
842
843struct ContinuousTileSizeSpecification
844 : public detail::ContinuousTileSizeSpecificationBase<Value> {};
845struct StaticContinuousTileSizeSpecification
846 : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
847
848/// Emits the IR computing the multi-sized tiling specification with two tile
849/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
850/// that there exist numbers of tiles with these sizes that fully cover the
851/// given iteration space `dimension` of the structured `op`.
852///
853/// The computation is as follows:
854///
855/// b = originalTripCount floordiv sizeDivisor
856/// t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
857/// d = (b + t - 1) floordiv t
858/// s = (b floordiv d) * sizeDivisor
859/// v = b % d
860/// u = d - v
861///
862/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
863/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
864///
865/// s * u + (s + sizeDivisor) * v == original size,
866/// where s mod sizeDivisor = 0.
867///
868/// Expects all values to be positive. In some cases with the target tile size
869/// sufficiently close to the dimension shape and non-unit divisor, it is
870/// impossible to compute such sizes. If `emitAssertion` is set, also emit the
871/// assertion that size computation succeeded.
872///
873/// Returns the specification consisting of both tile values and the number of
874/// tiles of each size.
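///
/// For example (an illustrative instance of the formulas above), for an
/// original trip count of 60 with `targetSize` = 8 and `sizeDivisor` = 1:
///   b = 60, t = 8, d = 8, s = 7, v = 4, u = 4,
/// i.e. 4 tiles of size 7 and 4 tiles of size 8 cover 4 * 7 + 4 * 8 = 60.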
875FailureOr<MultiSizeSpecification>
876computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
877 OpFoldResult targetSize, OpFoldResult divisor,
878 bool emitAssertions = true);
879FailureOr<StaticMultiSizeSpecification>
880computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
881 int64_t divisor);
882
883FailureOr<StaticContinuousTileSizeSpecification>
884computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
885 unsigned targetSize);
886FailureOr<ContinuousTileSizeSpecification>
887computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
888 unsigned dimension, OpFoldResult targetSize,
889 bool emitAssertions);
890
891/// Transformation information returned after reduction tiling.
892struct ForallReductionTilingResult {
  /// The partial reduction tiled ops generated.
894 SmallVector<Operation *> parallelTiledOps;
895 /// The final reduction operation merging all the partial reductions.
896 SmallVector<Operation *> mergeOps;
897 /// Initial values used for partial reductions.
898 SmallVector<Value> initialValues;
  /// The `scf.forall` operation that iterates over the tiles.
900 scf::ForallOp loops;
901};
902
903/// Method to tile a reduction to parallel iterations computing partial
/// reductions. After the loop, all the partial reductions are merged into a
/// final reduction. For example, the following sequence
906///
907/// ```mlir
908/// %0 = linalg.generic %in ["parallel", "reduction"]
909/// : tensor<7x9xf32> -> tensor<7xf32>
910/// ```
911///
/// is transformed into:
913///
914/// ```mlir
915/// %0 = linalg.fill ... : tensor<7x4xf32>
916/// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0)
917/// -> (tensor<7x4xf32>) {
///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> to tensor<7xf32>
919/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
920/// %4 = linalg.generic %2, %3 ["parallel", "reduction"]
921/// : tensor<7x?xf32> -> tensor<7xf32>
///   %5 = tensor.insert_slice %4, %arg0[0, %iv] : tensor<7x4xf32>
923/// }
924/// %6 = linalg.generic %1 ["parallel", "reduction"]
925/// : tensor<7x4xf32> -> tensor<7xf32>
926/// ```
927FailureOr<ForallReductionTilingResult>
928tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op,
929 ArrayRef<OpFoldResult> numThreads,
930 ArrayRef<OpFoldResult> tileSizes = {},
931 std::optional<ArrayAttr> mapping = std::nullopt);
932
933/// All indices returned by IndexOp should be invariant with respect to
934/// tiling. Therefore, if an operation is tiled, we have to transform the
935/// indices accordingly, i.e. offset them by the values of the corresponding
936/// induction variables that are captured implicitly in the body of the op.
937///
938/// Example. `linalg.generic` before tiling:
939///
940/// #id_2d = (i, j) -> (i, j)
941/// #pointwise_2d_trait = {
942/// indexing_maps = [#id_2d, #id_2d],
943/// iterator_types = ["parallel", "parallel"]
944/// }
945/// linalg.generic #pointwise_2d_trait %operand, %result {
946/// ^bb0(%operand_in: f32, %result_in: f32):
947/// %i = linalg.index 0 : index
948/// %j = linalg.index 1 : index
949/// <some operations that use %i, %j>
950/// }: memref<50x100xf32>, memref<50x100xf32>
951///
/// After the tiling pass with tile sizes 10 and 25:
953///
954/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
955///
956/// %c1 = arith.constant 1 : index
957/// %c0 = arith.constant 0 : index
958/// %c25 = arith.constant 25 : index
959/// %c10 = arith.constant 10 : index
960/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
961/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
962/// scf.for %k = %c0 to operand_dim_0 step %c10 {
963/// scf.for %l = %c0 to operand_dim_1 step %c25 {
964/// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
965/// : memref<50x100xf32> to memref<?x?xf32, #strided>
966/// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
967/// : memref<50x100xf32> to memref<?x?xf32, #strided>
968/// linalg.generic pointwise_2d_trait %4, %5 {
969/// ^bb0(%operand_in: f32, %result_in: f32):
970/// %i = linalg.index 0 : index
971/// %j = linalg.index 1 : index
972/// // Indices `k` and `l` are implicitly captured in the body.
///       %transformed_i = arith.addi %i, %k : index // `i` is offset by %k
///       %transformed_j = arith.addi %j, %l : index // `j` is offset by %l
///       // Every use of %i, %j is replaced with %transformed_i,
///       // %transformed_j.
///       <some operations that use %transformed_i, %transformed_j>
979/// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
980/// }
981/// }
982///
983/// TODO: Investigate whether mixing implicit and explicit indices
984/// does not lead to losing information.
985void transformIndexOps(RewriterBase &b, LinalgOp op,
986 SmallVectorImpl<Value> &ivs,
987 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
988
989/// Apply transformation to split the single linalg op reduction into a
990/// parallel and reduction dimension. Then create a new linalg.generic op
991/// doing the rest of the reduction. Return the new linalg op with an extra
992/// parallel dimension or failure if the transformation didn't happen.
993///
994/// Example:
995/// ```
996/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
997/// affine_map<(d0) -> ()>],
998/// iterator_types = ["reduction"]}
999/// ins(%in : tensor<32xf32>)
1000/// outs(%out : tensor<f32>) {
1001/// ^bb0(%arg1: f32, %arg2: f32):
1002/// %y = arith.addf %arg1, %arg2 : f32
1003/// linalg.yield %y : f32
1004/// } -> tensor<f32>
1005/// ```
1006/// To:
1007/// ```
1008/// %cst = arith.constant 0.000000e+00 : f32
1009/// %0 = tensor.expand_shape %in [[0, 1]]: tensor<32xf32> into tensor<4x8xf32>
1010/// %1 = tensor.empty [4] : tensor<4xf32>
1011/// %2 = linalg.fill ins(%cst : f32)
1012/// outs(%1 : tensor<4xf32>) -> tensor<4xf32>
1013/// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
1014/// affine_map<(d0, d1) -> (d0)>],
1015/// iterator_types = ["parallel", "reduction"]}
1016/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
///   ^bb0(%arg3: f32, %arg4: f32):
1018/// %5 = arith.addf %arg3, %arg4 : f32
1019/// linalg.yield %5 : f32
1020/// } -> tensor<4xf32>
1021/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1022/// affine_map<(d0) -> ()>],
1023/// iterator_types = ["reduction"]}
1024/// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
1025/// ^bb0(%arg3: f32, %arg4: f32):
1026/// %5 = arith.addf %arg3, %arg4 : f32
1027/// linalg.yield %5 : f32
1028/// } -> tensor<f32>
1029/// ```
1030struct SplitReductionResult {
1031 Operation *initOrAlloc;
1032 FillOp fillOp;
1033 LinalgOp splitLinalgOp;
1034 LinalgOp resultCombiningLinalgOp;
1035};
1036FailureOr<SplitReductionResult>
1037splitReduction(RewriterBase &b, LinalgOp op,
1038 const ControlSplitReductionFn &controlSplitReductionFn,
1039 bool useAlloc = false);
1040
1041/// Scaling-based implementation of the split reduction transformation.
1042/// Instead of introducing an ExpandShapeOp, this rewrites a reduction
1043/// dimension `k` into `k * scale + kk`.
1044///
1045/// Example:
1046/// ```
1047/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
1048/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
1049/// ```
1050///
1051/// Is transformed to:
1052///
1053/// ```
1054/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
1055/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
1056/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
1057/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1058/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
1059/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
1060/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32>
1061/// %cst = arith.constant 0.000000e+00 : f32
1062/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
1063/// tensor<16x32x64xf32>
1064/// %2 = tensor.empty [64, 4] : tensor<64x4xi1>
1065///
1066/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
1067/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
1068/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>,
1069/// tensor<64x4xi1>)
1070/// outs(%1 : tensor<16x32x64xf32>) {
1071/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
1072/// %5 = arith.mulf %arg3, %arg4 : f32
1073/// %6 = arith.addf %arg6, %5 : f32
1074/// linalg.yield %6 : f32
1075/// } -> tensor<16x32x64xf32>
1076///
1077/// %4 = linalg.generic {indexing_maps = [#map4, #map5],
1078/// iterator_types = ["parallel", "parallel", "reduction"]}
///   ins(%3 : tensor<16x32x64xf32>)
1080/// outs(%C : tensor<16x32xf32>) {
1081/// ^bb0(%arg3: f32, %arg4: f32):
1082/// %5 = arith.addf %arg3, %arg4 : f32
1083/// linalg.yield %5 : f32
1084/// } -> tensor<16x32xf32>
1085///
1086/// return %4 : tensor<16x32xf32>
1087/// ```
1088FailureOr<SplitReductionResult>
1089splitReductionByScaling(RewriterBase &b, LinalgOp op,
1090 const ControlSplitReductionFn &controlSplitReductionFn,
1091 bool useAlloc = false);
1092
/// Return `true` if a given sequence of dimensions is contiguous in the
/// range of the specified indexing map.
1095bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
1096/// Return `true` if all sequences of dimensions specified in `dimSequences` are
1097/// contiguous in all the ranges of the `maps`.
1098bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
1099 ArrayRef<ReassociationIndices> dimSequences);
1100
1101struct CollapseResult {
1102 SmallVector<Value> results;
1103 LinalgOp collapsedOp;
1104};
1105
/// Collapses dimensions of linalg.generic/linalg.copy operations. A
/// precondition to calling this method is that for each list in
/// `foldedIterationDims`, the
1108/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
/// the `linalgOp`. This can be checked using `areDimSequencesPreserved`.
1110/// When valid, the method also collapses the operands of the op. Returns
1111/// replacement values of the results of the original `linalgOp` by inserting
1112/// reshapes to get back values of compatible types.
1113FailureOr<CollapseResult>
1114collapseOpIterationDims(LinalgOp op,
1115 ArrayRef<ReassociationIndices> foldedIterationDims,
1116 RewriterBase &rewriter);
1117
1118struct LowerPackResult {
1119 tensor::PadOp padOp;
1120 tensor::ExpandShapeOp expandShapeOp;
1121 linalg::TransposeOp transposeOp;
1122};
1123
1124/// Rewrite pack as pad + reshape + transpose.
1125FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
1126 linalg::PackOp packOp,
1127 bool lowerPadLikeWithInsertSlice = true);
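
// Example usage: a lowering sketch. `rewriter` and `packOp` are assumed to be
// provided by an enclosing pattern.
//
//   FailureOr<LowerPackResult> res = lowerPack(rewriter, packOp);
//   if (failed(res))
//     return failure();
//   // The individual pad/expand_shape/transpose ops are available in `res`
//   // for further transformation.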
1128
1129struct LowerUnPackOpResult {
1130 tensor::EmptyOp emptyOp;
1131 linalg::TransposeOp transposeOp;
1132 tensor::CollapseShapeOp collapseShapeOp;
1133 tensor::ExtractSliceOp extractSliceOp;
1134};
1135
/// Rewrite unpack as empty + transpose + reshape + extract_slice.
1137FailureOr<LowerUnPackOpResult>
1138lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
1139 bool lowerUnpadLikeWithExtractSlice = true);
1140
1141/// Struct to hold the result of a `pack` call.
1142struct PackResult {
1143 SmallVector<linalg::PackOp> packOps;
1144 linalg::LinalgOp packedLinalgOp;
1145 SmallVector<linalg::UnPackOp> unPackOps;
1146};
1147/// Implement packing of a single LinalgOp by `packedSizes`.
1148/// There must be one packedSizes entry per `linalgOp` iterator.
1149/// Return the packed Linalg op on success, failure otherwise.
1150FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1151 ArrayRef<OpFoldResult> packedSizes);
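
// Example usage: a packing sketch for a matmul-like op with three iterators,
// one packed size per iterator. `rewriter` and `linalgOp` are assumed; the
// packed sizes are illustrative.
//
//   SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
//                                            rewriter.getIndexAttr(16),
//                                            rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> packed = pack(rewriter, linalgOp, packedSizes);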
1152
1153/// Struct to hold the result of a `packTranspose` call.
1154struct PackTransposeResult {
1155 linalg::PackOp transposedPackOp;
1156 linalg::LinalgOp transposedLinalgOp;
1157 linalg::UnPackOp transposedUnPackOp;
1158};
1159/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
1160/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
/// Return failure if any of the following conditions are not met:
///   1. the `packOp` has the `linalgOp` as its unique use.
///   2. the `maybeUnPackOp`, if specified, is a consumer of the result tied to
///      the unique `packOp` use.
///   3. `outerPerm` (resp. `innerPerm`) is a valid permutation of
///      `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty.
1167FailureOr<PackTransposeResult>
1168packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
1169 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
1170 ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
1171
1172/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
1173/// and n are proper parallel dimensions and k is a proper reduction
1174/// dimension. Packing occurs by rewriting the op as a linalg.generic and
1175/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
1176/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
1177/// to reorder {m, n, k} into one of the 8 possible forms. The outer
1178/// dimensions of the operands are not permuted at this time, this is left for
1179/// future work.
1180FailureOr<PackResult>
1181packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
1182 ArrayRef<OpFoldResult> mnkPackedSizes,
1183 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
1184 ArrayRef<int64_t> mnkOrder);
1185
1186struct BlockPackMatmulOptions {
  /// Minor block factors (mb, nb, kb) for packing relayout where mb, nb are
  /// the parallel dimensions and kb is the reduction dimension.
1189 SmallVector<int64_t, 3> blockFactors;
1190
1191 /// If true, allows packing of dimensions that only partially fit into the
1192 /// block factors.
1193 bool allowPadding = true;
1194
1195 /// Next multiples of the packing sizes.
1196 SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
1197
1198 /// Permutation of matmul (M, N, K) dimensions order.
1199 SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
1200
1201 /// Transpose LHS outer block layout [MB][KB] -> [KB][MB].
1202 bool lhsTransposeOuterBlocks = false;
1203
1204 /// Transpose LHS inner block layout [mb][kb] -> [kb][mb].
1205 bool lhsTransposeInnerBlocks = false;
1206
1207 /// Transpose RHS outer block layout [KB][NB] -> [NB][KB].
1208 bool rhsTransposeOuterBlocks = true;
1209
1210 /// Transpose RHS inner block layout [kb][nb] -> [nb][kb].
1211 bool rhsTransposeInnerBlocks = true;
1212};
1213
1214/// Function type which is used to control matmul packing.
1215/// It is expected to return valid packing configuration for each operation.
1216/// Lack of packing options indicates that no valid configuration could be
1217/// assigned and the operation will not be packed.
1218using ControlBlockPackMatmulFn =
1219 std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>;
1220
1221/// Pack a matmul operation into blocked 4D layout.
1222///
1223/// Relayout a matmul operation into blocked layout with two levels of
1224/// subdivision:
1225/// - major 2D blocks - outer dimensions, consist of minor blocks
1226/// - minor 2D blocks - inner dimensions, consist of scalar elements
1227///
1228/// A 2D matmul MxNxK gets reshaped into blocked 4D representation
1229/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
1230/// where the (MB, NB, KB) dimensions represent the major blocks,
1231/// and the (mb, nb, kb) are the minor blocks of their respective
1232/// original 2D dimensions (M, N, K).
1233///
1234/// Depending on the initial operands' data layout and the specified
1235/// packing options, the major blocks dimensions might get transposed
1236/// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
1237/// e.g., [mb][kb] -> [kb][mb].
1238/// Any present batch dimensions remain unchanged.
1239/// The final result is unpacked back to the original shape.
1240///
1241/// Return failure if no valid packing options are provided.
1242FailureOr<PackResult>
1243blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1244 const ControlBlockPackMatmulFn &controlPackMatmul);
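
// Example usage: a control-function sketch that blocks every matmul with the
// same factors. `rewriter` and `matmulOp` are assumed; the block factors are
// illustrative.
//
//   ControlBlockPackMatmulFn control =
//       [](linalg::LinalgOp op) -> std::optional<BlockPackMatmulOptions> {
//     BlockPackMatmulOptions options;
//     options.blockFactors = {32, 32, 64};  // (mb, nb, kb)
//     options.allowPadding = true;
//     return options;
//   };
//   FailureOr<PackResult> packed = blockPackMatmul(rewriter, matmulOp, control);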
1245
1246/// Rewrite tensor.from_elements to linalg.generic.
1247FailureOr<Operation *>
1248rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1249 tensor::FromElementsOp fromElementsOp);
1250
1251/// Rewrite tensor.generate to linalg.generic.
1252FailureOr<Operation *>
1253rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1254 tensor::GenerateOp generateOp);
1255
1256/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
1257FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1258 tensor::PadOp padOp);
1259
1260/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
1261/// and linalg.matmul.
1262///
1263/// A convolution operation can be written as a matrix-matrix multiplication by
/// unfolding the cross-correlation between input and filter and explicitly
/// copying the overlapping sliding window inputs.
1266///
1267/// Consider 2D input X with single channel input and output and 2x2 filter W:
1268/// [x(0, 0) , x(0, 1) , ..., x(0, n) ]
1269/// [x(1, 0) , x(1, 1) , ..., x(1, n) ]
1270/// [. , . ,. , . ] [w(0, 0), w(0, 1)]
1271/// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)]
1272/// [. , . , ., . ]
1273/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
1274///
1275/// The packed input data (img2col) is a matrix with |rows| = output spatial
/// size, |columns| = filter spatial size. To compute the output Y(i, j) we
/// need to calculate the dot product between the filter window at input
/// X(x, y) and the filter, which looks like the following, where the r.h.s.
/// is the img2col matrix and the l.h.s. is the flattened filter:
1280///
1281/// [x(0,0), x(0,1), x(1,0), x(1,1)]
1282/// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
1283/// [x(0,1), x(1,1), x(0,2), x(1,2)]
1284/// [ . , . , . , . ]
1285///
1286/// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter
1287/// and output (N, Ho, Wo, D) the convolution is the following matrix-matrix
/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each of the N
/// inputs. For the case where N > 1 it is a batched matrix-matrix
/// multiplication.
1291///
1292/// On success, return both the operation that produces the img2col tensor and
1293/// the final operation of the sequence that replaces the original convolution.
1294FailureOr<std::pair<Operation *, Operation *>>
1295rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
1296
1297/// Same as the above but for Fhwc channel orderings in the filter. In this case
1298/// the matrix multiplication is actually a row-wise dot-product rather than a
1299/// row-column dot-product. This is to avoid transposing the filter matrix which
1300/// would be required for a regular matrix multiplication to produce the correct
1301/// output dimensions.
1302FailureOr<std::pair<Operation *, Operation *>>
1303rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
1304
/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp, except that there
/// is no reduction among the input channels, so each convolution can be a
/// matrix-vector product. By transposing both the input and the filter so that
/// the channels are outermost, the computation becomes a batched matrix-vector
/// product.
1309FailureOr<std::pair<Operation *, Operation *>>
1310rewriteInIm2Col(RewriterBase &rewriter,
1311 linalg::DepthwiseConv2DNhwcHwcOp convOp);
1312
1313/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the
1314/// channels are to the left of the image shape dimensions, the position of the
1315/// contraction dimension in the resulting matmul is reversed. This swaps the
1316/// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) *
1317/// (C x Kh x Kw, Ho x Wo))
1318FailureOr<std::pair<Operation *, Operation *>>
1319rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
1320
1321/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by
1322/// materializing transpose.
1323FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1324 linalg::Conv2DNhwcFhwcOp op);
1325FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1326 linalg::Conv2DNhwcFhwcQOp op);
1327
1328/// Convert Linalg matmul ops to transposed variants.
1329FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
1330 linalg::MatmulOp op,
1331 bool transposeLHS = true);
1332FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
1333 linalg::BatchMatmulOp op,
1334 bool transposeLHS = true);
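
// A minimal sketch of driving the transposition rewrite on a single op. It
// assumes `matmulOp` is in scope and that `transposeLHS` selects which operand
// gets a materialized linalg.transpose.
// ```
//   IRRewriter rewriter(matmulOp->getContext());
//   rewriter.setInsertionPoint(matmulOp);
//   // Rewrite so that the LHS is produced by an explicit transpose; pass
//   // /*transposeLHS=*/false to transpose the RHS instead.
//   if (failed(linalg::transposeMatmul(rewriter, matmulOp,
//                                      /*transposeLHS=*/true)))
//     return failure();
// ```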

/// Convert linalg.conv_2d_nhwc_fhwc to the Winograd Conv2D algorithm
/// F(m x m, r x r), where m is the output tile size and r is the filter size.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r);
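
// A minimal call sketch, assuming `convOp` is a supported, statically shaped
// linalg.conv_2d_nhwc_fhwc. For F(4 x 4, 3 x 3), m = 4 and r = 3, so each
// transformed input tile has size (m + r - 1) x (m + r - 1) = 6 x 6.
// ```
//   IRRewriter rewriter(convOp->getContext());
//   rewriter.setInsertionPoint(convOp);
//   if (failed(linalg::winogradConv2D(rewriter, convOp, /*m=*/4, /*r=*/3)))
//     return failure();
// ```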

/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
/// the rewriting, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op);

/// Rewrite linalg.winograd_input_transform. The data layout of the input is
/// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
/// and tileW. After the rewriting, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<N x H x W x C>
///                              at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<alphaH x alphaW x tileH x tileW x N x C>
///                            at [0, 0, %h, %w, %n, %c]
FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op);

/// Rewrite linalg.winograd_output_transform. The data layout of the output is
/// HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
/// and tileW. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<alphaH x alphaW x tileH x tileW x N x F>
///                              at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<alphaH x alphaW> into
///                            output<N x H x W x F>
///                            at [%n, (%h x m), (%w x m), %f]
FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op);
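
// A minimal sketch of decomposing the three Winograd transform ops after
// winogradConv2D has introduced them. It assumes it runs inside a pass's
// runOnOperation() with `funcOp` in scope and uses the greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h.
// ```
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeWinogradOpsPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
// ```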

/// Method to deduplicate operands and remove dead results of `linalg.generic`
/// operations. This is effectively DCE for a linalg.generic op. If there is
/// deduplication of operands or removal of results, replaces the `genericOp`
/// with a new op and returns it. Returns the same operation if there is no
/// deduplication/removal.
FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
    RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs);
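
// A minimal call sketch; the pattern wrapper
// populateEraseUnusedOperandsAndResultsPatterns below is the more common entry
// point. It assumes `genericOp` is in scope and that `removeOutputs` controls
// whether dead outputs may also be dropped.
// ```
//   IRRewriter rewriter(genericOp.getContext());
//   rewriter.setInsertionPoint(genericOp);
//   FailureOr<linalg::GenericOp> cleaned =
//       linalg::deduplicateOperandsAndRemoveDeadResults(
//           rewriter, genericOp, /*removeOutputs=*/true);
//   // On success, `*cleaned` is either a freshly created op or the original
//   // op when nothing could be removed.
// ```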

//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
// functional-style API call.
//===----------------------------------------------------------------------===//

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DOp> {
  using OpRewritePattern<Conv2DOp>::OpRewritePattern;

  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
                                               PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};

extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                             Conv1DNwcWcfOp>;
extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                             Conv1DNcwFcwOp>;

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}

  FailureOr<DepthwiseConv1DNwcWcOp>
  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                           PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};

struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DOp>(context, benefit) {}

  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
                                               PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};

///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// See `generalization` for more details.
//
// TODO: Automatic default pattern class that just unwraps a function
// returning FailureOr<GenericOp>.
struct LinalgGeneralizationPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<GenericOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
    return generalizeNamedOp(rewriter, op);
  }

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }
};

struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  FailureOr<GenericOp>
  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
    return specializeGenericOp(rewriter, op);
  }

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }
};
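
// A minimal sketch of using the two patterns above to move between named ops
// and linalg.generic, assuming a pass context with `ctx` and `funcOp` in scope
// and the greedy driver applyPatternsAndFoldGreedily.
// ```
//   RewritePatternSet patterns(ctx);
//   // Either generalize named ops to linalg.generic ...
//   patterns.add<linalg::LinalgGeneralizationPattern>(ctx);
//   // ... or raise eligible linalg.generic ops back to named ops (typically
//   // not both in the same set):
//   // patterns.add<linalg::LinalgSpecializationPattern>(ctx);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
// ```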

/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

using OptimizeCopyFn =
    std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>;

/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a linalg::PackOp into a sequence of:
///   * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
///     tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
///   %packed = linalg.pack %input
///     padding_value(%pad : f32)
///     inner_dims_pos = [1, 0]
///     inner_tiles = [2, %high]
///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
/// ```
///
/// After:
/// ```
///   // PadOp
///   %padded = tensor.pad %input low[0, 0] high[%0, 1] {
///     ^bb0(...):
///       tensor.yield %pad : f32
///   } : tensor<5x1xf32> to tensor<?x2xf32>
///   // EmptyOp + TransposeOp
///   %empty = tensor.empty(%tile_dim_1) : tensor<2x?xf32>
///   %transposed = linalg.transpose
///     ins(%padded : tensor<?x2xf32>)
///     outs(%empty : tensor<2x?xf32>)
///     permutation = [1, 0]
///   // InsertSliceOp
///   %inserted_slice = tensor.insert_slice %transposed
///     into %output[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct DecomposeOuterUnitDimsPackOpPattern
    : public OpRewritePattern<linalg::PackOp> {
  using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the outer dims of the input linalg::UnPackOp are 1.
///
/// Before:
/// ```
///   %unpacked = linalg.unpack %input
///     inner_dims_pos = [1, 0]
///     inner_tiles = [2, 8]
///     into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
/// ```
///
/// After:
/// ```
///   // Rank-reduced extract to obtain the tile
///   %slice = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
///     : tensor<1x1x2x8xf32> to tensor<2x8xf32>
///   // EmptyOp + TransposeOp
///   %init = tensor.empty() : tensor<8x2xf32>
///   %transposed = linalg.transpose
///     ins(%slice : tensor<2x8xf32>)
///     outs(%init : tensor<8x2xf32>) permutation = [1, 0]
///   // Extract a slice matching the specified output size
///   %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
///     : tensor<8x2xf32> to tensor<5x1xf32>
/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
    : public OpRewritePattern<linalg::UnPackOp> {
  using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override;
};
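
// A minimal sketch of applying the two decomposition patterns above
// (equivalently, populateDecomposePackUnpackPatterns below registers both).
// It assumes a pass context with `ctx` and `funcOp` in scope.
// ```
//   RewritePatternSet patterns(ctx);
//   patterns.add<linalg::DecomposeOuterUnitDimsPackOpPattern,
//                linalg::DecomposeOuterUnitDimsUnPackOpPattern>(ctx);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
// ```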

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = memref.view %alloc ...
///    %subView = subview %allocOrView ...
///    [optional] linalg.fill(%allocOrView, %cst) ...
///    ...
///    memref.copy(%in, %subView) ...
///    vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = memref.view %alloc ...
///    [unchanged] %subView = subview %allocOrView ...
///    ...
///    vector.transfer_read %in[...], %cst ...
/// ```
/// Where there is no interleaved use between memref.copy and transfer_read as
/// well as no interleaved use between linalg.fill and memref.copy (if
/// linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = memref.view %alloc ...
///    %subView = subview %allocOrView ...
///    ...
///    vector.transfer_write %..., %allocOrView[...]
///    memref.copy(%subView, %out)
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = memref.view %alloc ...
///    [unchanged] %subView = subview %allocOrView ...
///    ...
///    vector.transfer_write %..., %out[...]
/// ```
/// Where there is no interleaved use between transfer_write and memref.copy.
/// This is a custom rewrite to forward partial writes to
/// vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};
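
// A minimal sketch of registering the two forwarding patterns, typically after
// vectorization has introduced the transfer ops; assumes `ctx` and `funcOp`
// are in scope.
// ```
//   RewritePatternSet patterns(ctx);
//   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
//                linalg::LinalgCopyVTWForwardingPattern>(ctx);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
// ```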

/// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  /// A function to control pattern application and rewrite logic.
  ///
  /// The function will be given the slice op and should return:
  ///   - std::nullopt: to fail the match and not apply the pattern;
  ///   - true: to apply the pattern with zero slice guard;
  ///   - false: to apply the pattern without zero slice guard.
  ///
  /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice
  /// guard.
  using ControlFn = std::function<std::optional<bool>(tensor::ExtractSliceOp)>;

  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
                                     ControlFn controlFn = nullptr,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override;

private:
  ControlFn controlFn;
};
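
// A minimal sketch of a ControlFn that only applies the swap when the slice
// source is produced by a tensor.pad and always requests the zero-slice guard;
// assumes `ctx` is in scope.
// ```
//   RewritePatternSet patterns(ctx);
//   auto control = [](tensor::ExtractSliceOp sliceOp) -> std::optional<bool> {
//     auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
//     if (!padOp)
//       return std::nullopt;  // Fail the match.
//     return true;            // Apply with the zero-slice guard.
//   };
//   patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(ctx, control);
// ```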

//===----------------------------------------------------------------------===//
// Populate functions.
//===----------------------------------------------------------------------===//

/// Canonicalization patterns relevant to apply after tiling patterns. These
/// are applied automatically by the tiling pass but need to be applied
/// manually when tiling is called programmatically.
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);

/// Linalg generalization patterns

/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns to convert linalg.generic ops to named
/// ops where possible. A linalg.generic can represent a wide range of complex
/// computations for which an equivalent linalg named op may not exist, e.g. a
/// linalg.generic that takes a tensor and computes a polynomial such as:
///   p(x) = an*x^n + ... + a1*x + a0
/// There is no equivalent named op to convert to. Many such cases exist.
void populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like
/// `linalg.transpose` into an elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);

/// Linalg decompose convolutions patterns

/// Populates patterns to decompose high-D convolution ops into low-D ones.
/// This is a step in progressive lowering for convolution ops; afterwards we
/// can vectorize the low-D convolution ops.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Populates patterns to decompose linalg.pack and linalg.unpack ops into e.g.
/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Requires all
/// outer dims to be unit.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);

/// Populates patterns to decompose tensor.pad into e.g.
/// tensor.empty, linalg.fill, tensor.insert_slice.
void populateDecomposePadPatterns(RewritePatternSet &patterns);

/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
/// \see rewriteInIm2Col for more details.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that vectorize tensor.pad.
/// These patterns are meant to apply in a complementary fashion. Benefits
/// are used to encode a certain ordering of pattern application. To avoid
/// scattering magic constants throughout the code base, the patterns must be
/// added with this function. `baseBenefit` can be used to offset the benefit
/// of all tensor::PadOp vectorization patterns by a certain value.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                        PatternBenefit baseBenefit = 1);

/// Populate patterns for splitting a `LinalgOp` with multiple statements within
/// its payload into multiple `GenericOp` that have a single statement.
/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
/// and results from the generated decomposed ops. This defaults to `true` since
/// the core decomposition patterns rely on these cleanup patterns; it is set
/// to `false` only for testing purposes.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
                                       bool removeDeadArgsAndResults = true);

/// Populate patterns that convert non-destination-style ops to destination
/// style ops.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);

/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops; it assumes that high-D convolution
/// ops were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);

/// Populate patterns that are only useful in the context of sparse tensors.
void populateSparseTensorRewriting(RewritePatternSet &patterns);

/// Function type which is used to control when to stop fusion. It is expected
/// that the OpOperand is not modified in the callback. The OpOperand is not
/// marked as const to allow callers to use non-const methods.
using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;

/// Patterns for fusing linalg operations on tensors.

/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpFusion);
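
// A minimal sketch of a fusion control function that only fuses when the
// producing value has a single use, paired with the populate call above;
// assumes `ctx` is in scope.
// ```
//   RewritePatternSet patterns(ctx);
//   linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
//     // Only fuse when the produced value is not used anywhere else.
//     return fusedOperand->get().hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
// ```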

/// Function type which is used to control propagation of linalg.pack/unpack
/// ops.
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;

/// Patterns to bubble up or down data layout ops across other operations.
void populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation);

/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

/// Patterns to promote inputs to outputs and remove unused inputs of
/// `linalg.generic` ops.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);

/// Function type to control generic op dimension collapsing. It is expected
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;

/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
/// tensor operands when needed and expand back the result tensors.
void populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions);
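
// A minimal sketch of a control function that asks for the two innermost loops
// of every op with at least three loops to be collapsed into one; assumes
// `patterns` is in scope.
// ```
//   linalg::GetCollapsableDimensionsFn control =
//       [](linalg::LinalgOp linalgOp) -> SmallVector<ReassociationIndices> {
//     int64_t numLoops = linalgOp.getNumLoops();
//     if (numLoops < 3)
//       return {};
//     return {ReassociationIndices{numLoops - 2, numLoops - 1}};
//   };
//   linalg::populateCollapseDimensions(patterns, control);
// ```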

/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to fold an expanding tensor.expand_shape operation with its
/// producer generic operation by collapsing the dimensions of the generic op.
void populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to constant fold Linalg operations.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                          const ControlFusionFn &controlFn);

/// Pattern to replace `linalg.add` when destination passing on a contraction op
/// suffices for achieving the sum.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);

/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
    RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors via reassociative reshape ops.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
                                        ControlDropUnitDims &options);

/// A pattern that converts init operands to input operands.
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);

/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);

/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);

/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
/// linalg.fill(%cst, tensor.extract_slice(%init)).
void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);

/// Adds patterns that make the broadcasts and transposes in the input operands
/// of a genericOp explicit.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    bool useAlloc = false);

/// Patterns to convert Linalg matmul ops to transposed variants.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                     bool transposeLHS = true);

/// Patterns to block pack Linalg matmul ops.
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                     const ControlBlockPackMatmulFn &controlFn);

/// Patterns to apply the Winograd Conv2D algorithm F(m x m, r x r).
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r);

/// Patterns to decompose Winograd operators.
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);

/// Adds patterns that reduce the rank of named contraction ops that have
/// unit dimensions in the operand(s) by converting to a sequence of
/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape`
/// (if on tensors). For example, a `linalg.batch_matmul` with unit batch size
/// will convert to `linalg.matmul` and a `linalg.matvec` with unit spatial dim
/// in its lhs will convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `linalg.pack` and `linalg.unpack` operations
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like `linalg.pack`
/// and `linalg.unpack` into `tensor.empty`.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that simplify `linalg.pack` and
/// `linalg.unpack` operations.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H