1//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
10#define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
11
12#include <utility>
13
14#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Linalg/Utils/Utils.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22#include "mlir/Dialect/X86Vector/Transforms.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Interfaces/TilingInterface.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/ADT/SmallBitVector.h"
28#include "llvm/ADT/SmallSet.h"
29
namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
class BufferizationState;
} // namespace bufferization

namespace linalg {

class LinalgOp;
enum class WinogradConv2DFmr : uint32_t;

//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//

/// Return the vector::CombiningKind corresponding to the given op, or
/// std::nullopt if `combinerOp` is not a recognized combiner operation.
std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
48
49//===----------------------------------------------------------------------===//
50// Bufferization-related transforms.
51//===----------------------------------------------------------------------===//
52
/// Options that control how the `bufferizeToAllocation` entry points below
/// materialize tensor values into newly allocated buffers.
struct BufferizeToAllocationOptions {
  /// The op to use for creating the buffer allocation.
  enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
  AllocOp allocOp = AllocOp::MemrefAlloc;

  /// The op to use for copying tensor data into the allocated buffer.
  enum class MemcpyOp {
    MaterializeInDestination = 0,
    MemrefCopy = 1,
    LinalgCopy = 2
  };
  MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination;

  /// If set to "true", only the destination tensor operands are bufferized to
  /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
  /// targeted op itself.
  bool bufferizeDestinationOnly = false;

  /// If set to "true", a `memref.dealloc` operation will be emitted for each
  /// allocated buffer. Otherwise, the memory is leaked, which is useful if
  /// the buffer deallocation pipeline should be run after bufferization is
  /// done.
  bool emitDealloc = false;
};
75
/// Materialize a buffer allocation for the given tensor.pad op and lower the
/// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
/// E.g.:
///
/// %0 = tensor.pad low[%l] high[%h] %t ...
///
/// is lowered to:
///
/// %alloc = memref.alloc
/// linalg.fill ... outs(%alloc)
/// %subview = memref.subview %alloc [%l] [...] [1]
/// bufferization.materialize_in_destination %t in %subview
/// %0 = bufferization.to_tensor %alloc restrict writable
///
/// In addition to rewriting the IR as shown above, this function returns the
/// newly allocated buffer. The `insertionPoint` parameter can be used to
/// specify a custom insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
                            const BufferizeToAllocationOptions &options,
                            tensor::PadOp padOp, Attribute memorySpace = {},
                            Operation *insertionPoint = nullptr);

/// Materialize a buffer allocation for the given vector.mask op and bufferize
/// the op, including its region. E.g.:
///
/// %0 = vector.mask {
///   vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
/// } : vector<16xi1> -> tensor<?xf32>
///
/// is lowered to:
///
/// %alloc = memref.alloc
/// bufferization.materialize_in_destination %t in %alloc
/// vector.mask {
///   vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
/// } : vector<16xi1>
/// %0 = bufferization.to_tensor %alloc restrict writable
///
/// In addition to rewriting the IR as shown above, this function returns the
/// newly allocated buffer. The `insertionPoint` parameter can be used to
/// specify a custom insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
                            const BufferizeToAllocationOptions &options,
                            vector::MaskOp maskOp, Attribute memorySpace = {},
                            Operation *insertionPoint = nullptr);

/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
/// and lower the op to memref.alloc + memref.tensor_store.
///
/// In addition to rewriting the IR, this function returns the newly allocated
/// buffer. The `insertionPoint` parameter can be used to specify a custom
/// insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
                            const BufferizeToAllocationOptions &options,
                            bufferization::AllocTensorOp allocTensorOp,
                            Attribute memorySpace = {},
                            Operation *insertionPoint = nullptr);

/// Bufferize the given op with tensor semantics and materialize the result in
/// a newly allocated buffer.
///
/// Only bufferizable ops that bufferize to a memory write or have an
/// aliasing OpOperand (and do not themselves bufferize to an allocation) are
/// supported. They are bufferized using their BufferizableOpInterface
/// implementation.
///
/// Selected ops that bufferize to an allocation (or need special handling) are
/// also supported:
/// - tensor.pad
/// - vector.mask
///
/// This function returns the newly allocated buffer. The `insertionPoint`
/// parameter can be used to specify a custom insertion point for the buffer
/// allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
                            const BufferizeToAllocationOptions &options,
                            Operation *op, Attribute memorySpace = {},
                            Operation *insertionPoint = nullptr);
154
/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
/// LinalgOp. This transform looks for LinalgOps that have an unused output
/// operand and an input operand that is rooted in a tensor::EmptyOp. The
/// tensor::EmptyOp uses are replaced with the output operand and the two
/// operands of the LinalgOp are swapped.
///
/// Example:
/// %0 = tensor.empty()
/// %1 = linalg.matmul ins(...) outs(%0)
/// %2 = linalg.generic ins(%1) outs(%dest) {
///   ^bb0(%in: f32, %out: f32):
///   // out not used
/// }
///
/// The IR is transformed as follows:
/// %0 = tensor.empty()
/// %1 = linalg.matmul ins(...) outs(%dest)
/// %2 = linalg.generic ins(%0) outs(%1) {
///   ^bb0(%in: f32, %out: f32):
///   // Use %out instead of %in
/// }
///
/// The "ins" operand has no uses inside the body of the LinalgOp and can be
/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp
/// can also fold away.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(
    RewriterBase &rewriter, Operation *op,
    bufferization::OneShotAnalysisState &state);
183
184//===----------------------------------------------------------------------===//
185// Structs that configure the behavior of various transformations.
186//===----------------------------------------------------------------------===//
187
188using TileSizeComputationFunction =
189 std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
190
191struct LinalgTilingOptions {
192 /// Computation function that returns the tile sizes for each operation.
193 /// Delayed construction of constant tile sizes should occur to interoperate
194 /// with folding.
195 TileSizeComputationFunction tileSizeComputationFunction = nullptr;
196
197 LinalgTilingOptions &
198 setTileSizeComputationFunction(TileSizeComputationFunction fun) {
199 tileSizeComputationFunction = std::move(fun);
200 return *this;
201 }
202 /// Set the `tileSizeComputationFunction` to return the values `ts`. The
203 /// values must not fold away when tiling. Otherwise, use a more robust
204 /// `tileSizeComputationFunction`.
205 LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
206 tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
207 return *this;
208 }
209 /// Convenience function to set the `tileSizeComputationFunction` to a
210 /// function that computes tile sizes at the point they are needed. Allows
211 /// proper interaction with folding.
212 LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
213
214 /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
215 /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
216 LinalgTilingOptions &scalarizeDynamicDims();
217
218 /// The interchange vector to reorder the tiled loops.
219 SmallVector<unsigned, 4> interchangeVector = {};
220
221 LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
222 interchangeVector.assign(in_start: interchange.begin(), in_end: interchange.end());
223 return *this;
224 }
225
226 /// The type of tile loops to generate.
227 LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;
228
229 LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
230 loopType = lt;
231 return *this;
232 }
233
234 /// When specified, specifies distribution of generated tile loops to
235 /// processors.
236 std::optional<LinalgLoopDistributionOptions> distribution;
237
238 LinalgTilingOptions &
239 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
240 distribution = std::move(distributionOptions);
241 return *this;
242 }
243
244 /// Specification markers of how to distribute the `linalg.tiled_loop`.
245 SmallVector<StringRef, 2> distributionTypes = {};
246
247 LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
248 distributionTypes.assign(in_start: types.begin(), in_end: types.end());
249 return *this;
250 }
251
252 /// Peel the specified loops.
253 SmallVector<int64_t> peeledLoops;
254
255 LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
256 peeledLoops.clear();
257 peeledLoops.append(in_start: loops.begin(), in_end: loops.end());
258 return *this;
259 }
260};
261
262struct LinalgTilingAndFusionOptions {
263 /// Tile sizes used to tile the root operation.
264 SmallVector<int64_t> tileSizes;
265 LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
266 tileSizes.assign(in_start: ts.begin(), in_end: ts.end());
267 return *this;
268 }
269 /// Tile interchange used to permute the tile loops.
270 SmallVector<int64_t> tileInterchange;
271 /// When specified, specifies distribution of generated tile loops to
272 /// processors.
273 std::optional<LinalgLoopDistributionOptions> tileDistribution;
274 LinalgTilingAndFusionOptions &
275 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
276 tileDistribution = std::move(distributionOptions);
277 return *this;
278 }
279};
280
281struct LinalgPaddingOptions {
282 /// A padding value for every operand.
283 SmallVector<Attribute> paddingValues;
284 LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
285 paddingValues.assign(in_start: pv.begin(), in_end: pv.end());
286 return *this;
287 }
288 /// A list of iterator dimensions to pad.
289 SmallVector<int64_t> paddingDimensions;
290 LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
291 paddingDimensions.assign(in_start: pd.begin(), in_end: pd.end());
292 return *this;
293 }
294 /// A list of multiples to which each padding dimension should be padded to.
295 std::optional<SmallVector<int64_t>> padToMultipleOf;
296 LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) {
297 padToMultipleOf.emplace(args: m.begin(), args: m.end());
298 return *this;
299 }
300 /// A mapping between an operand and shape dim, and a size for a padding
301 /// dimension. Each size is expected to be greater or equal than the
302 /// corresponding shape dim. If no value is provided then the constant upper
303 /// bound will be used.
304 DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;
305 LinalgPaddingOptions &setSizeToPadTo(unsigned operandIndex, unsigned dimIndex,
306 OpFoldResult size) {
307 assert(size && "expected non-null size");
308 sizeToPadTo[{operandIndex, dimIndex}] = size;
309 return *this;
310 }
311 /// Given the operand index and shape dim it returns the size to pad to.
312 OpFoldResult getSizeToPadTo(unsigned operandIndex, unsigned dimIndex) const {
313 return sizeToPadTo.lookup_or(
314 Val: std::pair<unsigned, unsigned>(operandIndex, dimIndex), Default: nullptr);
315 }
316
317 /// A flag for every operand to mark the PadOp as nofold which enables
318 /// packing for statically shaped operands.
319 SmallVector<bool> nofoldFlags;
320 LinalgPaddingOptions &setNofoldFlags(ArrayRef<bool> pp) {
321 nofoldFlags.assign(in_start: pp.begin(), in_end: pp.end());
322 return *this;
323 }
324 /// A number of loops to hoist the PadOp out for every operand.
325 SmallVector<int64_t> hoistPaddings;
326 LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
327 hoistPaddings.assign(in_start: hp.begin(), in_end: hp.end());
328 return *this;
329 }
330 /// A permutation vector for every operand used to transpose the packed
331 /// PadOp results.
332 SmallVector<SmallVector<int64_t>> transposePaddings;
333 LinalgPaddingOptions &
334 setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
335 transposePaddings.assign(in_start: tp.begin(), in_end: tp.end());
336 return *this;
337 }
338 enum class CopyBackOp : int8_t {
339 None = 0,
340 BufferizationMaterializeInDestination = 1,
341 LinalgCopy = 2
342 };
343 /// The op to be used for copying the padded result to the original
344 /// destination tensor.
345 CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination;
346 LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
347 copyBackOp = op;
348 return *this;
349 }
350};
351
352struct PadTilingInterfaceOptions {
353 /// A padding value for every operand.
354 SmallVector<Attribute> paddingValues;
355 PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
356 paddingValues.assign(in_start: pv.begin(), in_end: pv.end());
357 return *this;
358 }
359 /// A list of iterator dimensions sizes to pad to.
360 SmallVector<OpFoldResult> paddingSizes;
361 PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
362 paddingSizes.assign(in_start: m.begin(), in_end: m.end());
363 return *this;
364 }
365 /// Pad iterator `paddingDimension[i]` to next multiple of `paddingSizes[i]`
366 /// if true. Otherwise pad to `paddingSizes[i]`.
367 bool padToMultipleOf;
368 PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
369 padToMultipleOf = b;
370 return *this;
371 }
372};
373
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use (or std::nullopt on
/// failure).
using AllocBufferCallbackFn = std::function<std::optional<Value>(
    OpBuilder &b, memref::SubViewOp subView,
    ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;

/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
using DeallocBufferCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value buffer)>;

/// Callback function type used to insert copy from original subview to
/// subview of the promoted region for the read operands/subview of promoted
/// region to original subview for the results. The copy has to happen from
/// `src` to `dst`.
using CopyCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
394
395struct LinalgPromotionOptions {
396 /// Indices of subViews to promote. If `std::nullopt`, try to promote all
397 /// operands.
398 std::optional<DenseSet<unsigned>> operandsToPromote;
399 LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
400 operandsToPromote = DenseSet<unsigned>();
401 operandsToPromote->insert_range(R&: operands);
402 return *this;
403 }
404 /// If ith element of `useFullTiles` is true the full view should be used
405 /// for the promoted buffer of the ith operand in `operandsToPromote`.
406 /// Otherwise the partial view will be used. The decision is defaulted to
407 /// `useFullTileBuffersDefault` when `useFullTileBuffers` is std::nullopt and
408 /// for operands missing from `useFullTileBuffers`.
409 std::optional<llvm::SmallBitVector> useFullTileBuffers;
410 LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
411 unsigned size = useFullTiles.size();
412 llvm::SmallBitVector tmp(size, false);
413 for (unsigned i = 0; i < size; ++i)
414 tmp[i] = useFullTiles[i];
415 useFullTileBuffers = tmp;
416 return *this;
417 }
418 /// If true all operands unspecified by `useFullTileBuffers` will use the
419 /// full view, otherwise the partial view.
420 bool useFullTileBuffersDefault = false;
421 LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
422 useFullTileBuffersDefault = use;
423 return *this;
424 }
425 /// If true, buffers will be allocated with the original subview size. This
426 /// may result in more dynamic allocations, in case of dynamic sizes.
427 bool useOriginalSubviewSize = false;
428 LinalgPromotionOptions &setUseOriginalSubviewSize(bool originalSize) {
429 useOriginalSubviewSize = originalSize;
430 return *this;
431 }
432 /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment.
433 std::optional<unsigned> alignment;
434 LinalgPromotionOptions &setAlignment(unsigned align) {
435 alignment = align;
436 return *this;
437 }
438 /// Memory space of promoted buffer. If `std::nullopt` do not specify memory
439 /// space.
440 std::optional<Attribute> memorySpace;
441 LinalgPromotionOptions &setMemorySpace(Attribute memorySpc) {
442 memorySpace = memorySpc;
443 return *this;
444 }
445 /// Use alloca with the default allocation scheme.
446 bool useAlloca = false;
447 LinalgPromotionOptions &setUseAlloca(bool use) {
448 useAlloca = use;
449 return *this;
450 }
451 /// Callback function to do the allocation of the promoted buffer. If
452 /// std::nullopt, then the default allocation scheme of allocating a
453 /// memref<?xi8> buffer followed by a view operation is used.
454 std::optional<AllocBufferCallbackFn> allocationFn;
455 std::optional<DeallocBufferCallbackFn> deallocationFn;
456 LinalgPromotionOptions &
457 setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
458 DeallocBufferCallbackFn const &deallocFn) {
459 allocationFn = allocFn;
460 deallocationFn = deallocFn;
461 return *this;
462 }
463 /// Callback function to do the copy of data to and from the promoted
464 /// subview. If std::nullopt then a memref.copy is used.
465 std::optional<CopyCallbackFn> copyInFn;
466 std::optional<CopyCallbackFn> copyOutFn;
467 LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
468 CopyCallbackFn const &copyOut) {
469 copyInFn = copyIn;
470 copyOutFn = copyOut;
471 return *this;
472 }
473};
474
/// Split Reduction options.
struct SplitReductionOptions {
  // Ratio used to split the reduction dimension. If the ratio is <= 1,
  // nothing will be done.
  int64_t ratio = 0;
  // Index where the extra dimension is added to the intermediate tensor
  // shape.
  unsigned index = 0;
  // Whether the inner dimension obtained after splitting is parallel;
  // otherwise it is a reduction.
  bool innerParallel = false;
};

/// Function signature to control reduction splitting. This returns
/// `SplitReductionOptions` for a given LinalgOp.
// TODO: don't use unsigned unless doing bit manipulation.
using ControlSplitReductionFn =
    std::function<SplitReductionOptions(LinalgOp op)>;
492
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//

/// Return true if two `linalg.generic` operations with producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
bool areElementwiseOpsFusable(OpOperand *fusedOperand);

/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options);

/// Return success if the operation can be vectorized (with the given vector
/// sizes / scalable-dimension flags, if any).
LogicalResult vectorizeOpPrecondition(Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes = {},
                                      ArrayRef<bool> inputScalableVecDims = {},
                                      bool vectorizeNDExtract = false,
                                      bool flatten1DDepthwiseConv = false);

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

/// A list of loop operations.
using LinalgLoops = SmallVector<Operation *, 4>;
519
520/// Transformation to drop unit-extent dimensions from `linalg.generic`
521/// operations.
522struct ControlDropUnitDims {
523 enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
524
525 RankReductionStrategy rankReductionStrategy =
526 RankReductionStrategy::ReassociativeReshape;
527
528 using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
529 ControlFnTy controlFn = [](Operation *op) {
530 if (auto genericOp = dyn_cast_or_null<GenericOp>(Val: op)) {
531 return llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: genericOp.getNumLoops()));
532 }
533 if (auto padOp = dyn_cast_or_null<tensor::PadOp>(Val: op)) {
534 return llvm::to_vector(
535 Range: llvm::seq<unsigned>(Begin: 0, End: padOp.getSourceType().getRank()));
536 }
537 return SmallVector<unsigned>{};
538 };
539};
/// Result of `dropUnitDims`.
struct DropUnitDimsResult {
  /// The new op after dropping unit-extent dimensions.
  linalg::GenericOp resultOp;
  /// Replacement values for the results of the original op.
  SmallVector<Value> replacements;
};
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
                                           GenericOp genericOp,
                                           const ControlDropUnitDims &options);
547
/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
struct ElementwiseOpFusionResult {
  /// The op created by the fusion.
  Operation *fusedOp;
  /// Replacement values keyed by the original values they replace.
  llvm::DenseMap<Value, Value> replacements;
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
///   agree with the result of this method. This function gives a prediction
///   based on an optimized fusion.
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
                                                     GenericOp consumer,
                                                     OpOperand *fusedOperand);
566
/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);

/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
/// operands to a static bounding box. The original `opToPad` is cloned and
/// operates on the padded tensors.
///
/// * "options.padToMultipleOf" indicates that each padding dimension should be
///   padded to the specified multiple.
/// * Use "options.paddingValues" and "options.nofoldFlags" to set padding
///   value and nofold attribute of the created tensor::PadOps, respectively.
/// * The unpadded results (extracted slice of the cloned operation) are
///   returned via `replacements`.
/// * The tensor::PadOps are returned via `padOps`.
/// * "options.copyBackOp" specifies the op type for copying back the unpadded
///   result to the original destination tensor.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                const LinalgPaddingOptions &options,
                                LinalgOp &paddedOp,
                                SmallVector<Value> &replacements,
                                SmallVector<tensor::PadOp> &padOps);

/// Helper function to compute the padded shape of the given value `v` of
/// `RankedTensorType` given:
///   - the `indexingSizes` as a list of OpFoldResult.
///   - an `indexingMap` that encodes how the padded shape varies with
///     increases in `indexingSizes`.
/// The implementation iteratively combines increases from contributing terms
/// using affine.apply operations.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
/// provides a gentle portability path for Linalg-like ops with affine maps.
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
                   AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
                   const PadTilingInterfaceOptions &options);

/// Callback that computes the padded sizes for one operand of an op, given
/// the op's iteration domain and the padding options.
using PadSizeComputationFunction =
    std::function<FailureOr<SmallVector<OpFoldResult>>(
        RewriterBase &, OpOperand &, ArrayRef<Range>,
        const PadTilingInterfaceOptions &)>;

/// Specific helper for Linalg ops.
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
    RewriterBase &rewriter, OpOperand &operandToPad,
    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
620
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
///
/// * "options.paddingSizes" indicates that each padding dimension should be
///   padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
///   interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
///   tensor::PadOp.
/// * The tensor::PadOp is returned on success.
FailureOr<TilingInterface>
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
                  const PadTilingInterfaceOptions &constOptions,
                  SmallVector<tensor::PadOp> &padOps,
                  PadSizeComputationFunction computePaddingSizeFun =
                      &computeIndexingMapOpInterfacePaddedShape);
637
namespace detail {

/// Helper struct to hold the results of building a packing loop nest.
struct PackingResult {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  // Transpose op created when a non-empty transpose vector was requested
  // (may be null otherwise).
  TransposeOp maybeTransposeOp;
  // The hoisted tensor.pad op.
  tensor::PadOp hoistedPadOp;
};

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
FailureOr<PackingResult>
buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist,
                     scf::ForOp outermostEnclosingForOp,
                     ArrayRef<int64_t> transposeVector);

} // namespace detail
657
/// Mechanically hoist padding operations on tensors by `numLoops` into a new,
/// generally larger tensor. This achieves packing of multiple padding ops into
/// a larger tensor. On success, `opToHoist` is replaced by the cloned version
/// in the packing loop so the caller can continue reasoning about the padding
/// operation. If `transposeVector` is non-empty, hoist padding introduces a
/// TransposeOp to transpose the padded tensor before inserting it into the
/// packed tensor. A `transposeVector` can change the storage order of the
/// padded tensor but does not change the order of the pack or compute loops.
///
/// TODO: In the future, we should consider rewriting as a linalg.pack after
/// hoisting since this abstraction is now available.
///
/// Example in pseudo-mlir:
/// =======================
///
/// If hoistPaddingOnTensors is called with `numLoops` = 2 on the following IR.
/// ```
///    scf.for (%i, %j, %k)
///      %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
///      %0 = tensor.pad %st0 low[0, 0] high[...] {
///      ^bb0( ... ):
///        linalg.yield %pad
///      } : tensor<?x?xf32> to tensor<4x8xf32>
///      compute(%0)
/// ```
///
/// IR resembling the following is produced:
///
/// ```
///    scf.for (%i) {
///      %packed_init = tensor.empty range(%j) : tensor<?x4x8xf32>
///      %packed = scf.for (%k) iter_args(%p : %packed_init) {
///        %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
///        %0 = tensor.pad %st0 low[0, 0] high[...] {
///        ^bb0( ... ):
///          linalg.yield %pad
///        } : tensor<?x?xf32> to tensor<4x8xf32>
///        %1 = tensor.insert_slice %0 ...
///            : tensor<4x8xf32> to tensor<?x4x8xf32>
///        scf.yield %1: tensor<?x4x8xf32>
///      } -> tensor<?x4x8xf32>
///      scf.for (%j, %k) {
///        %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
///                 tensor<?x4x8xf32> to tensor<4x8xf32>
///        compute(%st0)
///      }
///    }
/// ```
FailureOr<Value>
hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist,
                      int64_t numLoops, ArrayRef<int64_t> transposeVector,
                      tensor::PadOp &hoistedOp,
                      SmallVectorImpl<TransposeOp> &transposeOps);
/// Calls into `hoistPaddingOnTensors` with a local IRRewriter.
FailureOr<Value>
hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
                      ArrayRef<int64_t> transposeVector,
                      tensor::PadOp &hoistedOp,
                      SmallVectorImpl<TransposeOp> &transposeOps);

/// Apply padding and hoisting to `linalgOp` according to the configuration
/// specified in `options`.
FailureOr<LinalgOp> padAndHoistLinalgOp(RewriterBase &rewriter,
                                        LinalgOp linalgOp,
                                        const LinalgPaddingOptions &options);
723
/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.
/// If the second part is statically known to be empty, do not create it
/// and return nullptr instead. Error state is signalled by returning
/// a pair of nullptrs.
///
/// For example, the following op:
///
///   linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
///                 outs(%2 : tensor<128x64xf32>)
///
/// split along the first dimension at position 42 will result in:
///
///   %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
///   %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
///   %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
///                      outs(%4 : tensor<42x64xf32>)
///   %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
///
///   %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
///   %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
///   %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
///                      outs(%8 : tensor<86x64xf32>)
///   tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
///
/// Note that there is no simplification other than constant propagation applied
/// to slice extraction and insertion.
std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter,
                                                    TilingInterface op,
                                                    unsigned dimension,
                                                    OpFoldResult splitPoint);
755
/// Perform standalone tiling of a single LinalgOp by `tileSizes`,
/// and permute the loop nest according to `interchangeVector`.
758/// The permutation is expressed as a list of integers that specify
759/// the new ordering of the loop nest. The length of `interchangeVector`
760/// must be equal to the length of `tileSizes`.
761/// An empty vector is interpreted as the identity permutation and the
762/// transformation returns early.
763///
764/// Return a struct containing the tiled loops in the specified order
765/// and the cloned op if successful, std::nullopt otherwise.
766///
767/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
768/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
769/// integers, in the range 0..`tileSizes.size()` without duplications
770/// (i.e. `[1,1,2]` is an invalid permutation).
struct TiledLinalgOp {
  /// The cloned, tiled operation.
  LinalgOp op;
  /// The generated tile loops, in the (possibly interchanged) order requested
  /// by the tiling options.
  SmallVector<Operation *, 8> loops;
  /// Tensor-typed results produced by the tiled computation, if any.
  /// NOTE(review): presumably empty when operating on memrefs — confirm.
  SmallVector<Value, 4> tensorResults;
};
776FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
777 const LinalgTilingOptions &options);
778
/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled
781/// by `interchangeVector`. An empty vector is interpreted as the identity
782/// permutation and the transformation returns early.
783///
784/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
785/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
786/// integers, in the range 0..`op.rank` without duplications
787/// (i.e. `[1,1,2]` is an invalid permutation).
788///
789/// Return failure if the permutation is not valid.
790FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
791 GenericOp genericOp,
792 ArrayRef<unsigned> interchangeVector);
793
794/// Create a GenericOp from the given named operation `linalgOp` and replace
795/// the given `linalgOp`.
796/// Return failure if `linalgOp` is a GenericOp or misses a region builder.
797FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
798 LinalgOp linalgOp);
799
800/// Create a namedOp from the given GenericOp and replace the GenericOp.
801/// Currently we can specialize only trivial linalg copy operations.
802FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
803 GenericOp genericOp);
804
805/// Create a new buffer using the `allocationFn` provided. The size of this
806/// buffer is either the original subview size when 'useOriginalSubviewSize' is
807/// set to true or the smallest constant bounding size along each dimension that
808/// can be computed for the size of the result of `subView`. Returns the
809/// allocated buffer as `fullLocalView` and the view that matches the size of
810/// the result of subview operation as `partialLocalView`.
struct PromotionInfo {
  /// The newly allocated full-size buffer (see `promoteSubviewAsNewBuffer`).
  Value fullLocalView;
  /// A view into `fullLocalView` that matches the size of the result of the
  /// original subview operation.
  Value partialLocalView;
};
815FailureOr<PromotionInfo>
816promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
817 bool useOriginalSubviewSize,
818 const AllocBufferCallbackFn &allocationFn,
819 DataLayout &layout);
820
821/// Promote the `subViews` into a new buffer allocated at the insertion point
822/// `b`. Promotion occurs in 3 steps:
823/// 1. Create a new buffer for a full tile (i.e. not clipped at the
824/// boundary).
825/// 2. Take a full view on the buffer.
826/// 3. Take a partial slice of the full view in step 2. and copy into it.
827///
828/// Return the modified linalg op (the modification happens in place) as well
829/// as all the copy ops created.
830FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
831 const LinalgPromotionOptions &options);
832
833/// Allocate the subview in the GPU workgroup memory.
834std::optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
835 memref::SubViewOp subview,
836 ArrayRef<Value> sizeBounds,
837 DataLayout &);
838
839/// In case of GPU group memory there is no need to deallocate.
840LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/);
841
842/// Create Memref copy operations and add gpu barrier guards before and after
843/// the copy operation to ensure data integrity.
844LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst);
845
846/// Allocate the subview in the GPU private memory.
847std::optional<Value> allocateGPUPrivateMemory(OpBuilder &builder,
848 memref::SubViewOp subview,
849 ArrayRef<Value> sizeBounds,
850 DataLayout &);
851
/// Normal copy between src and dst.
853LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
854
855/// In case of GPU private memory there is no need to deallocate since the
856/// memory is freed when going outside of the scope.
857LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
858
859/// Return true if there's dedicated logic in the Linalg Vectorizer to
860/// vectorize this Op, false otherwise.
861///
862/// Note that this helper merely implements a very high level check and that the
863/// vectorizer also requires various additional pre-conditions to be met for it
864/// to work (these are checked by the vectorizer itself).
865bool hasVectorizationImpl(Operation *);
866
/// Transformation information returned after vectorizing.
struct VectorizationResult {
  /// Results of the vectorization transform to replace the original operation.
  /// NOTE(review): presumably one replacement value per result of the original
  /// op — confirm against the `vectorize` implementation.
  SmallVector<Value> replacements;
};
872/// Returns a `VectorizationResult` containing the results of the vectorized op,
873/// or failure if the transformation fails. If provided, `inputVectorSizes` are
874/// used to vectorize this operation. `inputVectorSizes` must match the rank of
875/// the iteration space of the operation and the input vector sizes must be
876/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorSizes` also allows the vectorization of operations with dynamic
/// shapes.
879FailureOr<VectorizationResult>
880vectorize(RewriterBase &rewriter, Operation *op,
881 ArrayRef<int64_t> inputVectorSizes = {},
882 ArrayRef<bool> inputScalableVecDims = {},
883 bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
884
885/// Emit a suitable vector form for a Copy op with fully static shape.
886LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
887
888/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
889FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
890 LinalgOp linalgOp);
891
892/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
893FailureOr<LinalgLoops> linalgOpToParallelLoops(RewriterBase &rewriter,
894 LinalgOp linalgOp);
895
896/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
897FailureOr<LinalgLoops> linalgOpToAffineLoops(RewriterBase &rewriter,
898 LinalgOp linalgOp);
899
900/// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
901/// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument
902/// has one entry per surrounding loop. It uses zero as the convention that a
903/// particular loop is not tiled. This convention simplifies implementations
904/// by avoiding affine map manipulations. The returned ranges correspond to
905/// the loop ranges, in the proper order, that are tiled and for which new
906/// loops will be created. Also the function returns a map from loop indices
907/// of the LinalgOp to the corresponding non-empty range indices of newly
908/// created loops.
909using LoopIndexToRangeIndexMap = DenseMap<int, int>;
910std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
911makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
912 ArrayRef<OpFoldResult> allShapeSizes,
913 ArrayRef<OpFoldResult> allTileSizes);
914
915namespace detail {
template <typename T>
struct MultiSizeSpecificationBase {
  /// The two computed tile sizes (`s` and `s + sizeDivisor` in the
  /// `computeMultiTileSizes` documentation).
  T lowTileSize, highTileSize;
  /// Number of tiles associated with each size, so that
  /// `lowTileSize * lowTripCount + highTileSize * highTripCount` covers the
  /// original iteration space.
  T lowTripCount, highTripCount;
};
923
template <typename T>
struct ContinuousTileSizeSpecificationBase {
  /// Tile sizes, one entry per generated tile-size class.
  SmallVector<T> tileSizes;
  /// Number of tiles associated with each size (parallel to `tileSizes`).
  SmallVector<T> tripCounts;
};
931
932} // namespace detail
933
934/// A description of a multi-size tiling comprising tile sizes and numbers of
935/// tiles, expressed as Values which may or may not be constant. Multi-size
936/// currently means two-size.
/// Multi-size specification whose sizes and trip counts are SSA values
/// (possibly non-constant).
struct MultiSizeSpecification
    : public detail::MultiSizeSpecificationBase<Value> {};
/// Variant of `MultiSizeSpecification` with statically known (int64_t) sizes
/// and trip counts.
struct StaticMultiSizeSpecification
    : public detail::MultiSizeSpecificationBase<int64_t> {};
941
/// Continuous tile size specification whose sizes and trip counts are SSA
/// values (possibly non-constant).
struct ContinuousTileSizeSpecification
    : public detail::ContinuousTileSizeSpecificationBase<Value> {};
/// Variant of `ContinuousTileSizeSpecification` with statically known
/// (int64_t) sizes and trip counts.
struct StaticContinuousTileSizeSpecification
    : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
946
947/// Emits the IR computing the multi-sized tiling specification with two tile
948/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
949/// that there exist numbers of tiles with these sizes that fully cover the
950/// given iteration space `dimension` of the structured `op`.
951///
952/// The computation is as follows:
953///
954/// b = originalTripCount floordiv sizeDivisor
955/// t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
956/// d = (b + t - 1) floordiv t
957/// s = (b floordiv d) * sizeDivisor
958/// v = b % d
959/// u = d - v
960///
961/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
962/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
963///
964/// s * u + (s + sizeDivisor) * v == original size,
965/// where s mod sizeDivisor = 0.
966///
967/// Expects all values to be positive. In some cases with the target tile size
968/// sufficiently close to the dimension shape and non-unit divisor, it is
969/// impossible to compute such sizes. If `emitAssertion` is set, also emit the
970/// assertion that size computation succeeded.
971///
972/// Returns the specification consisting of both tile values and the number of
973/// tiles of each size.
974FailureOr<MultiSizeSpecification>
975computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
976 OpFoldResult targetSize, OpFoldResult divisor,
977 bool emitAssertions = true);
978FailureOr<StaticMultiSizeSpecification>
979computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
980 int64_t divisor);
981
982FailureOr<StaticContinuousTileSizeSpecification>
983computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
984 unsigned targetSize);
985FailureOr<ContinuousTileSizeSpecification>
986computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
987 unsigned dimension, OpFoldResult targetSize,
988 bool emitAssertions);
989
990/// Transformation information returned after reduction tiling.
struct ForallReductionTilingResult {
  /// The partial reduction tiled ops generated.
  SmallVector<Operation *> parallelTiledOps;
  /// The final reduction operations merging all the partial reductions.
  SmallVector<Operation *> mergeOps;
  /// Initial values used for partial reductions.
  SmallVector<Value> initialValues;
  /// The `scf.forall` operation that iterates over the tiles.
  scf::ForallOp loops;
};
1001
1002/// Method to tile a reduction to parallel iterations computing partial
1003/// reductions. After the loop all the partial reduction are merged into a final
1004/// reduction. For example for the following sequence
1005///
1006/// ```mlir
1007/// %0 = linalg.generic %in ["parallel", "reduction"]
1008/// : tensor<7x9xf32> -> tensor<7xf32>
1009/// ```
1010///
1011/// into:
1012///
1013/// ```mlir
1014/// %0 = linalg.fill ... : tensor<7x4xf32>
1015/// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0)
1016/// -> (tensor<7x4xf32>) {
///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> to tensor<7xf32>
1018/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
1019/// %4 = linalg.generic %2, %3 ["parallel", "reduction"]
1020/// : tensor<7x?xf32> -> tensor<7xf32>
///   %5 = tensor.insert_slice %4, %arg0[0, %iv] : tensor<7x4xf32>
1022/// }
1023/// %6 = linalg.generic %1 ["parallel", "reduction"]
1024/// : tensor<7x4xf32> -> tensor<7xf32>
1025/// ```
1026FailureOr<ForallReductionTilingResult>
1027tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op,
1028 ArrayRef<OpFoldResult> numThreads,
1029 ArrayRef<OpFoldResult> tileSizes = {},
1030 std::optional<ArrayAttr> mapping = std::nullopt);
1031
1032/// All indices returned by IndexOp should be invariant with respect to
1033/// tiling. Therefore, if an operation is tiled, we have to transform the
1034/// indices accordingly, i.e. offset them by the values of the corresponding
1035/// induction variables that are captured implicitly in the body of the op.
1036///
1037/// Example. `linalg.generic` before tiling:
1038///
1039/// #id_2d = (i, j) -> (i, j)
1040/// #pointwise_2d_trait = {
1041/// indexing_maps = [#id_2d, #id_2d],
1042/// iterator_types = ["parallel", "parallel"]
1043/// }
1044/// linalg.generic #pointwise_2d_trait %operand, %result {
1045/// ^bb0(%operand_in: f32, %result_in: f32):
1046/// %i = linalg.index 0 : index
1047/// %j = linalg.index 1 : index
1048/// <some operations that use %i, %j>
1049/// }: memref<50x100xf32>, memref<50x100xf32>
1050///
1051/// After tiling pass with tiles sizes 10 and 25:
1052///
1053/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
1054///
1055/// %c1 = arith.constant 1 : index
1056/// %c0 = arith.constant 0 : index
1057/// %c25 = arith.constant 25 : index
1058/// %c10 = arith.constant 10 : index
1059/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
1060/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
1061/// scf.for %k = %c0 to operand_dim_0 step %c10 {
1062/// scf.for %l = %c0 to operand_dim_1 step %c25 {
1063/// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
1064/// : memref<50x100xf32> to memref<?x?xf32, #strided>
1065/// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
1066/// : memref<50x100xf32> to memref<?x?xf32, #strided>
1067/// linalg.generic pointwise_2d_trait %4, %5 {
1068/// ^bb0(%operand_in: f32, %result_in: f32):
1069/// %i = linalg.index 0 : index
1070/// %j = linalg.index 1 : index
1071/// // Indices `k` and `l` are implicitly captured in the body.
1072/// %transformed_i = arith.addi %i, %k : index // index `i` is offset by
1073/// %k %transformed_j = arith.addi %j, %l : index // index `j` is offset
1074/// by %l
1075/// // Every use of %i, %j is replaced with %transformed_i,
1076/// %transformed_j <some operations that use %transformed_i,
1077/// %transformed_j>
1078/// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
1079/// }
1080/// }
1081///
1082/// TODO: Investigate whether mixing implicit and explicit indices
1083/// does not lead to losing information.
1084void transformIndexOps(RewriterBase &b, LinalgOp op,
1085 SmallVectorImpl<Value> &ivs,
1086 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
1087
1088/// Apply transformation to split the single linalg op reduction into a
1089/// parallel and reduction dimension. Then create a new linalg.generic op
1090/// doing the rest of the reduction. Return the new linalg op with an extra
1091/// parallel dimension or failure if the transformation didn't happen.
1092///
1093/// Example:
1094/// ```
1095/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1096/// affine_map<(d0) -> ()>],
1097/// iterator_types = ["reduction"]}
1098/// ins(%in : tensor<32xf32>)
1099/// outs(%out : tensor<f32>) {
1100/// ^bb0(%arg1: f32, %arg2: f32):
1101/// %y = arith.addf %arg1, %arg2 : f32
1102/// linalg.yield %y : f32
1103/// } -> tensor<f32>
1104/// ```
1105/// To:
1106/// ```
1107/// %cst = arith.constant 0.000000e+00 : f32
1108/// %0 = tensor.expand_shape %in [[0, 1]]: tensor<32xf32> into tensor<4x8xf32>
1109/// %1 = tensor.empty [4] : tensor<4xf32>
1110/// %2 = linalg.fill ins(%cst : f32)
1111/// outs(%1 : tensor<4xf32>) -> tensor<4xf32>
1112/// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
1113/// affine_map<(d0, d1) -> (d0)>],
1114/// iterator_types = ["parallel", "reduction"]}
1115/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
///   ^bb0(%arg3: f32, %arg4: f32):
1117/// %5 = arith.addf %arg3, %arg4 : f32
1118/// linalg.yield %5 : f32
1119/// } -> tensor<4xf32>
1120/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1121/// affine_map<(d0) -> ()>],
1122/// iterator_types = ["reduction"]}
1123/// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
1124/// ^bb0(%arg3: f32, %arg4: f32):
1125/// %5 = arith.addf %arg3, %arg4 : f32
1126/// linalg.yield %5 : f32
1127/// } -> tensor<f32>
1128/// ```
struct SplitReductionResult {
  /// The op producing the temporary accumulator (an init tensor or,
  /// presumably when `useAlloc` is set, an allocation — confirm).
  Operation *initOrAlloc;
  /// The fill op initializing the temporary accumulator (e.g. with 0.0 in the
  /// example above).
  FillOp fillOp;
  /// The new linalg op performing the bulk of the reduction, carrying the
  /// extra parallel dimension.
  LinalgOp splitLinalgOp;
  /// The final op combining the partial reductions into the original result.
  LinalgOp resultCombiningLinalgOp;
};
1135FailureOr<SplitReductionResult>
1136splitReduction(RewriterBase &b, LinalgOp op,
1137 const ControlSplitReductionFn &controlSplitReductionFn,
1138 bool useAlloc = false);
1139
1140/// Scaling-based implementation of the split reduction transformation.
1141/// Instead of introducing an ExpandShapeOp, this rewrites a reduction
1142/// dimension `k` into `k * scale + kk`.
1143///
1144/// Example:
1145/// ```
1146/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
1147/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
1148/// ```
1149///
1150/// Is transformed to:
1151///
1152/// ```
1153/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
1154/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
1155/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
1156/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1157/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
1158/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
1159/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32>
1160/// %cst = arith.constant 0.000000e+00 : f32
1161/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
1162/// tensor<16x32x64xf32>
1163/// %2 = tensor.empty [64, 4] : tensor<64x4xi1>
1164///
1165/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
1166/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
1167/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>,
1168/// tensor<64x4xi1>)
1169/// outs(%1 : tensor<16x32x64xf32>) {
1170/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
1171/// %5 = arith.mulf %arg3, %arg4 : f32
1172/// %6 = arith.addf %arg6, %5 : f32
1173/// linalg.yield %6 : f32
1174/// } -> tensor<16x32x64xf32>
1175///
1176/// %4 = linalg.generic {indexing_maps = [#map4, #map5],
1177/// iterator_types = ["parallel", "parallel", "reduction"]}
///        ins(%3 : tensor<16x32x64xf32>)
1179/// outs(%C : tensor<16x32xf32>) {
1180/// ^bb0(%arg3: f32, %arg4: f32):
1181/// %5 = arith.addf %arg3, %arg4 : f32
1182/// linalg.yield %5 : f32
1183/// } -> tensor<16x32xf32>
1184///
1185/// return %4 : tensor<16x32xf32>
1186/// ```
1187FailureOr<SplitReductionResult>
1188splitReductionByScaling(RewriterBase &b, LinalgOp op,
1189 const ControlSplitReductionFn &controlSplitReductionFn,
1190 bool useAlloc = false);
1191
1192/// Return `true` if a given sequence of dimensions are contiguous in the
1193/// range of the specified indexing map.
1194bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
1195/// Return `true` if all sequences of dimensions specified in `dimSequences` are
1196/// contiguous in all the ranges of the `maps`.
1197bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
1198 ArrayRef<ReassociationIndices> dimSequences);
1199
struct CollapseResult {
  /// Replacement values for the results of the original op, reshaped back to
  /// the original (compatible) result types.
  SmallVector<Value> results;
  /// The linalg op with collapsed iteration dimensions.
  LinalgOp collapsedOp;
};
1204
1205/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
1206/// to calling this method is that for each list in `foldedIterationDim`, the
1207/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
1208/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
1209/// When valid, the method also collapses the operands of the op. Returns
1210/// replacement values of the results of the original `linalgOp` by inserting
1211/// reshapes to get back values of compatible types.
1212FailureOr<CollapseResult>
1213collapseOpIterationDims(LinalgOp op,
1214 ArrayRef<ReassociationIndices> foldedIterationDims,
1215 RewriterBase &rewriter);
1216
struct LowerPackResult {
  /// Pad op of the "pad + reshape + transpose" decomposition produced by
  /// `lowerPack`.
  tensor::PadOp padOp;
  /// Reshape op of the decomposition.
  tensor::ExpandShapeOp expandShapeOp;
  /// Transpose op of the decomposition.
  linalg::TransposeOp transposeOp;
};
1222
1223/// Rewrite pack as pad + reshape + transpose.
1224FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
1225 linalg::PackOp packOp,
1226 bool lowerPadLikeWithInsertSlice = true);
1227
struct LowerUnPackOpResult {
  /// Empty op of the "empty + transpose + reshape + extract_slice"
  /// decomposition produced by `lowerUnPack`.
  tensor::EmptyOp emptyOp;
  /// Transpose op of the decomposition.
  linalg::TransposeOp transposeOp;
  /// Reshape op of the decomposition.
  tensor::CollapseShapeOp collapseShapeOp;
  /// Extract_slice op of the decomposition.
  tensor::ExtractSliceOp extractSliceOp;
};
1234
/// Rewrite unpack as empty + transpose + reshape + extract_slice.
1236FailureOr<LowerUnPackOpResult>
1237lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
1238 bool lowerUnpadLikeWithExtractSlice = true);
1239
/// Struct to hold the result of a `pack` call.
struct PackResult {
  /// The pack ops created for the operands of the packed op.
  SmallVector<linalg::PackOp> packOps;
  /// The packed linalg op itself.
  linalg::LinalgOp packedLinalgOp;
  /// The unpack ops created for the results of the packed op.
  SmallVector<linalg::UnPackOp> unPackOps;
};
1246/// Implement packing of a single LinalgOp by `packedSizes`.
1247/// There must be one packedSizes entry per `linalgOp` iterator.
1248/// Return the packed Linalg op on success, failure otherwise.
1249FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1250 ArrayRef<OpFoldResult> packedSizes);
1251
/// Struct to hold the result of a `packTranspose` call.
struct PackTransposeResult {
  /// The transposed pack op of the PackOp -> LinalgOp -> UnPackOp chain.
  linalg::PackOp transposedPackOp;
  /// The linalg op rewritten to consume/produce the transposed layout.
  linalg::LinalgOp transposedLinalgOp;
  /// The transposed unpack op; presumably null when no `maybeUnPackOp` was
  /// supplied — confirm against `packTranspose`.
  linalg::UnPackOp transposedUnPackOp;
};
1258/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
1259/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
1260/// Return failure if either:
1261/// 1. the `packOp` does not have the `linalgOp` as its unique use.
1262/// 2. the `maybeUnPackOp`, if specified must be a consumer of the result tied
1263/// to the unique `packOp` use.
1264/// 3. `outerPerm` (resp. `innerPerm`) must be valid permutations of
1265/// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty.
1266FailureOr<PackTransposeResult>
1267packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
1268 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
1269 ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
1270
1271/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
1272/// and n are proper parallel dimensions and k is a proper reduction
1273/// dimension. Packing occurs by rewriting the op as a linalg.generic and
1274/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
1275/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
1276/// to reorder {m, n, k} into one of the 8 possible forms. The outer
1277/// dimensions of the operands are not permuted at this time, this is left for
1278/// future work.
1279FailureOr<PackResult>
1280packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
1281 ArrayRef<OpFoldResult> mnkPackedSizes,
1282 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
1283 ArrayRef<int64_t> mnkOrder);
1284
struct BlockPackMatmulOptions {
  /// Minor block factors (mb, nb, kb) for packing relayout where mb, nb are
  /// the parallel dimensions and kb is the reduction dimension.
  SmallVector<int64_t, 3> blockFactors;

  /// If true, allows packing of dimensions that only partially fit into the
  /// block factors.
  bool allowPadding = true;

  /// Next multiples of the packing sizes.
  SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;

  /// Permutation of matmul (M, N, K) dimensions order.
  SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};

  /// Transpose LHS outer block layout [MB][KB] -> [KB][MB].
  bool lhsTransposeOuterBlocks = false;

  /// Transpose LHS inner block layout [mb][kb] -> [kb][mb].
  bool lhsTransposeInnerBlocks = false;

  /// Transpose RHS outer block layout [KB][NB] -> [NB][KB].
  bool rhsTransposeOuterBlocks = true;

  /// Transpose RHS inner block layout [kb][nb] -> [nb][kb].
  bool rhsTransposeInnerBlocks = true;
};
1312
1313/// Function type which is used to control matmul packing.
1314/// It is expected to return valid packing configuration for each operation.
1315/// Lack of packing options indicates that no valid configuration could be
1316/// assigned and the operation will not be packed.
1317using ControlBlockPackMatmulFn =
1318 std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>;
1319
1320/// Pack a matmul operation into blocked 4D layout.
1321///
1322/// Relayout a matmul operation into blocked layout with two levels of
1323/// subdivision:
1324/// - major 2D blocks - outer dimensions, consist of minor blocks
1325/// - minor 2D blocks - inner dimensions, consist of scalar elements
1326///
1327/// A 2D matmul MxNxK gets reshaped into blocked 4D representation
1328/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
1329/// where the (MB, NB, KB) dimensions represent the major blocks,
1330/// and the (mb, nb, kb) are the minor blocks of their respective
1331/// original 2D dimensions (M, N, K).
1332///
1333/// Depending on the initial operands' data layout and the specified
1334/// packing options, the major blocks dimensions might get transposed
1335/// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
1336/// e.g., [mb][kb] -> [kb][mb].
1337/// Any present batch dimensions remain unchanged.
1338/// The final result is unpacked back to the original shape.
1339///
1340/// Return failure if no valid packing options are provided.
1341FailureOr<PackResult>
1342blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1343 const ControlBlockPackMatmulFn &controlPackMatmul);
1344
1345/// Rewrite tensor.from_elements to linalg.generic.
1346FailureOr<Operation *>
1347rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1348 tensor::FromElementsOp fromElementsOp);
1349
1350/// Rewrite tensor.generate to linalg.generic.
1351FailureOr<Operation *>
1352rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1353 tensor::GenerateOp generateOp);
1354
1355/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
1356FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1357 tensor::PadOp padOp);
1358
1359/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
1360/// and linalg.matmul.
1361///
1362/// A convolution operation can be written as a matrix-matrix multiplication by
1363/// unfolding the cross-correlation between input and filter and explicitly copy
1364/// overlapped sliding window inputs.
1365///
1366/// Consider 2D input X with single channel input and output and 2x2 filter W:
1367/// [x(0, 0) , x(0, 1) , ..., x(0, n) ]
1368/// [x(1, 0) , x(1, 1) , ..., x(1, n) ]
1369/// [. , . ,. , . ] [w(0, 0), w(0, 1)]
1370/// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)]
1371/// [. , . , ., . ]
1372/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
1373///
1374/// The packed input data (img2col) is a matrix with |rows| = output spatial
1375/// size, |columns| = filter spatial size. To compute the output Y(i, j) we need
1376/// to calculate the dot product between filter window at input X(x, y)) and the
1377/// filter which will look like the following where r.h.s is the img2col matrix
1378/// and l.h.s is the flattened filter:
1379///
1380/// [x(0,0), x(0,1), x(1,0), x(1,1)]
1381/// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
1382/// [x(0,1), x(1,1), x(0,2), x(1,2)]
1383/// [ . , . , . , . ]
1384///
1385/// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter
1386/// and output (N, Ho, Wo, D) the convolution is the following matrix-matrix
1387/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in
1388/// the N input. For the case where N > 1 its a batched matrix-matrix
1389/// multiplication.
1390///
1391/// On success, return both the operation that produces the img2col tensor and
1392/// the final operation of the sequence that replaces the original convolution.
1393FailureOr<std::pair<Operation *, Operation *>>
1394rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
1395
1396/// Same as the above but for Fhwc channel orderings in the filter. In this case
1397/// the matrix multiplication is actually a row-wise dot-product rather than a
1398/// row-column dot-product. This is to avoid transposing the filter matrix which
1399/// would be required for a regular matrix multiplication to produce the correct
1400/// output dimensions.
1401FailureOr<std::pair<Operation *, Operation *>>
1402rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
1403
1404/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
1405/// reduction among the input channels so each convolution can be a
1406/// matrix-vector product and by transposing both input filter so channels are
1407/// outer most the computation is a batched matrix-vector product.
1408FailureOr<std::pair<Operation *, Operation *>>
1409rewriteInIm2Col(RewriterBase &rewriter,
1410 linalg::DepthwiseConv2DNhwcHwcOp convOp);
1411
1412/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the
1413/// channels are to the left of the image shape dimensions, the position of the
1414/// contraction dimension in the resulting matmul is reversed. This swaps the
1415/// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) *
1416/// (C x Kh x Kw, Ho x Wo))
1417FailureOr<std::pair<Operation *, Operation *>>
1418rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
1419
1420/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by
1421/// materializing transpose.
1422FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1423 linalg::Conv2DNhwcFhwcOp op);
1424FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1425 linalg::Conv2DNhwcFhwcQOp op);
1426
1427/// Convert Linalg matmul ops to transposed variants.
1428FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
1429 linalg::MatmulOp op,
1430 bool transposeLHS = true);
1431FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
1432 linalg::BatchMatmulOp op,
1433 bool transposeLHS = true);
1434
1435/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
1436/// F(m x m, r x r). m is the dimension size of output and r is the dimension
1437/// size of filter.
1438FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1439 linalg::Conv2DNhwcFhwcOp op,
1440 WinogradConv2DFmr fmr);
1441
1442/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
1443/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1444/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
1445/// the rewriting, we get
1446///
1447/// scf.for %f = lo_f to hi_f step 1
1448/// scf.for %c = lo_c to hi_c step 1
1449/// %extracted = extract filter<h x w> from filter<f x h x w x c>
1450/// %ret = linalg.matmul G, %extracted
1451/// %ret = linalg.matmul %ret, GT
1452/// %inserted = insert %ret into filter<h x w x c x f>
1453FailureOr<Operation *>
1454decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1455 linalg::WinogradFilterTransformOp op);
1456
1457/// Rewrite linalg.winograd_input_transform. The data layout of the input is
1458/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
1459/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
1460/// and tileW. After the rewriting, we get
1461///
1462/// scf.for %h = 0 to tileH step 1
1463/// scf.for %w = 0 to tileW step 1
1464/// scf.for %n = 0 to N step 1
1465/// scf.for %c = 0 to C step 1
1466/// %extracted = extract %extracted<alphaH x alphaW> from
1467/// %input<N x H x W x C>
1468/// at [%n, (%h x m), (%w x m), %c]
1469/// %ret = linalg.matmul BT, %extracted
1470/// %ret = linalg.matmul %ret, B
1471/// %inserted = insert %ret<alphaH x alphaW> into
1472/// %output<alphaH x alphaW x tileH x tileW x N x C>
1473/// at [0, 0, %h, %w, %n, %c]
1474FailureOr<Operation *>
1475decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1476 linalg::WinogradInputTransformOp op);
1477
1478/// Rewrite linalg.winograd_output_transform. The data layout of the output is
1479/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
1480/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
1481/// and tileW. After the transformation, we get
1482///
1483/// scf.for %h = 0 to tileH step 1
1484/// scf.for %w = 0 to tileW step 1
1485/// scf.for %n = 0 to N step 1
1486/// scf.for %f = 0 to F step 1
1487/// %extracted = extract %extracted<alphaH x alphaW> from
1488/// %input<alphaH x alphaW x tileH x tileW x N x F>
1489/// at [0, 0, %h, %w, %n, %f]
1490/// %ret = linalg.matmul AT, %extracted
1491/// %ret = linalg.matmul %ret, A
1492/// %inserted = insert %ret<alphaH x alphaW> into
1493/// output<N x H x W x F>
1494/// at [%n, (%h x m), (%w x m), %f]
1495FailureOr<Operation *>
1496decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1497 linalg::WinogradOutputTransformOp op);
1498
1499/// Method to deduplicate operands and remove dead results of `linalg.generic`
1500/// operations. This is effectively DCE for a linalg.generic op. If there is
/// deduplication of operands or removal of results, replaces the `genericOp`
1502/// with a new op and returns it. Returns the same operation if there is no
1503/// deduplication/removal.
1504FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
1505 RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs);
1506
1507//===----------------------------------------------------------------------===//
1508// Rewrite patterns wrapping transformations.
1509// TODO: every single such pattern should be a close to noop wrapper around a
// functional-style API call.
1511//===----------------------------------------------------------------------===//
1512
1513/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
1514/// convolution ops.
template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DOp> {
  using OpRewritePattern<Conv2DOp>::OpRewritePattern;

  /// Rewrite `convOp` (a 2-D convolution with a size-1 window dimension) into
  /// the corresponding `Conv1DOp`. Returns the new 1-D op on success, failure
  /// if the rewrite does not apply. Defined out-of-line; see the extern
  /// template instantiations below for the supported op pairs.
  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
                                               PatternRewriter &rewriter) const;

  /// Pattern entry point: forwards to `returningMatchAndRewrite`, discarding
  /// the returned op.
  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};
1528
1529extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1530 Conv1DNwcWcfOp>;
1531extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1532 Conv1DNcwFcwOp>;
1533
1534/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
1535/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}

  /// Rewrite `convOp` (2-D depthwise convolution with size-1 (w, kw) or
  /// (h, kh) dims) into a 1-D depthwise convolution. Returns the new op on
  /// success, failure otherwise. Defined out-of-line.
  FailureOr<DepthwiseConv1DNwcWcOp>
  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                           PatternRewriter &rewriter) const;

  /// Pattern entry point: forwards to `returningMatchAndRewrite`, discarding
  /// the returned op.
  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};
1551
/// Rewrites linalg.conv_2d into a 1-D convolution op; the exact applicability
/// conditions live in the out-of-line `returningMatchAndRewrite`.
struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DOp>(context, benefit) {}

  /// Rewrite `convOp` into a `Conv1DOp`. Returns the new op on success,
  /// failure otherwise. Defined out-of-line.
  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
                                               PatternRewriter &rewriter) const;

  /// Pattern entry point: forwards to `returningMatchAndRewrite`, discarding
  /// the returned op.
  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};
1564
1565///
1566/// Linalg generalization pattern.
1567///
1568/// Apply the `generalization` transformation as a pattern.
1569/// See `generalization` for more details.
1570//
1571// TODO: Automatic default pattern class that just unwraps a function
1572// returning FailureOr<GenericOp>.
1573struct LinalgGeneralizationPattern
1574 : public OpInterfaceRewritePattern<LinalgOp> {
1575 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1576
1577 /// `matchAndRewrite` implementation that returns the significant
1578 /// transformed pieces of IR.
1579 FailureOr<GenericOp>
1580 returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
1581 return generalizeNamedOp(rewriter, linalgOp: op);
1582 }
1583
1584 LogicalResult matchAndRewrite(LinalgOp op,
1585 PatternRewriter &rewriter) const override {
1586 return returningMatchAndRewrite(op, rewriter);
1587 }
1588};
1589
1590struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
1591 using OpRewritePattern<GenericOp>::OpRewritePattern;
1592
1593 FailureOr<GenericOp>
1594 returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
1595 return specializeGenericOp(rewriter, genericOp: op);
1596 }
1597
1598 LogicalResult matchAndRewrite(GenericOp op,
1599 PatternRewriter &rewriter) const override {
1600 return returningMatchAndRewrite(op, rewriter);
1601 }
1602};
1603
/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;

  /// Vectorize `copyOp`; the rewrite is defined out-of-line.
  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
1611
/// Callback type used to customize how the copy-in for a tensor::PadOp is
/// rewritten. NOTE(review): no user of this alias is visible in this chunk of
/// the header — confirm it is still referenced before relying on it.
using OptimizeCopyFn =
    std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>;
1614
1615/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
1616/// InsertSliceOp. For now, only constant padding values are supported.
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
  /// Rewrite `padOp` into EmptyOp + FillOp + InsertSliceOp (see the comment
  /// above the struct). Defined out-of-line.
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  /// Materialize the padding value of `padOp` into `dest`; `dynSizes` carries
  /// the dynamic sizes of `dest`. NOTE(review): the fill-vs-generate split
  /// is inferred from the name — confirm against the implementation.
  Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};
1628
1629/// Rewrites a linalg::PackOp into a sequence of:
1630/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
1631/// tensor::InsertSliceOp ops.
1632///
1633/// Requires that all the outer dims of the input linalg::PackOp are 1.
1634///
1635/// Before:
1636/// ```
1637/// %packed = linalg.pack %input
1638/// padding_value(%pad : f32)
1639/// inner_dims_pos = [1, 0]
1640/// inner_tiles = [2, %high]
1641/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
1642/// ```
1643///
1644/// After:
1645/// ```
1646/// // PadOp
1647/// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
1648/// ^bb0(...):
1649/// tensor.yield %arg2 : f32
1650/// } : tensor<5x1xf32> to tensor<?x2xf32>
1651/// // EmptyOp + TransposeOp
1652/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
1653/// %transposed = linalg.transpose
1654/// ins(%extracted_slice : tensor<?x2xf32>)
1655/// outs(%empty : tensor<2x?xf32>)
1656/// permutation = [1, 0]
1657/// // InsertSliceOp
1658/// %inserted_slice = tensor.insert_slice %transposed
1659/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
1660/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
1661/// ```
struct DecomposeOuterUnitDimsPackOpPattern
    : public OpRewritePattern<linalg::PackOp> {
  using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
  /// Perform the rewrite illustrated in the example above the struct.
  /// Defined out-of-line.
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override;
};
1668
1669/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
1670/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
1671///
/// Requires that all the outer dims of the input linalg::UnPackOp are 1.
1673///
1674/// Before:
1675/// ```
1676/// %packed = linalg.unpack %input
1677/// inner_dims_pos = [1, 0]
1678/// inner_tiles = [2, 8]
1679/// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
1680/// ```
1681///
1682/// After:
1683/// ```
1684/// // Rank-reduced extract to obtain the tile
1685/// %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
1686/// : tensor<1x1x2x8xf32> to tensor<2x8xf32>
1687/// // EmptyOp + TransposeOp
1688/// %init = tensor.empty() : tensor<8x2xf32>
1689/// %transposed = linalg.transpose
1690/// ins(%extracted_slice : tensor<2x8xf32>)
1691/// outs(%0 : tensor<8x2xf32>) permutation = [1, 0]
1692/// // Extract a slice matching the specified output size
1693/// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
1694/// : tensor<8x2xf32> to tensor<5x1xf32>
1695/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
    : public OpRewritePattern<linalg::UnPackOp> {
  using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
  /// Perform the rewrite illustrated in the example above the struct.
  /// Defined out-of-line.
  LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override;
};
1702
1703/// Match and rewrite for the pattern:
1704/// ```
1705/// %alloc = ...
1706/// [optional] %view = memref.view %alloc ...
1707/// %subView = subview %allocOrView ...
1708/// [optional] linalg.fill(%allocOrView, %cst) ...
1709/// ...
1710/// memref.copy(%in, %subView) ...
1711/// vector.transfer_read %allocOrView[...], %cst ...
1712/// ```
1713/// into
1714/// ```
1715/// [unchanged] %alloc = ...
1716/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] %subView = subview %allocOrView ...
1718/// ...
1719/// vector.transfer_read %in[...], %cst ...
1720/// ```
1721/// Where there is no interleaved use between memref.copy and transfer_read as
1722/// well as no interleaved use between linalg.fill and memref.copy (if
1723/// linalg.fill is specified).
1724/// This is a custom rewrite to forward partial reads (with optional fills) to
1725/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  /// Forward the source of the copy directly into `xferOp` when the
  /// no-interleaved-use conditions documented above hold. Defined
  /// out-of-line.
  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};
1733
1734/// Match and rewrite for the pattern:
1735/// ```
1736/// %alloc = ...
1737/// [optional] %view = memref.view %alloc ...
1738/// %subView = subview %allocOrView...
1739/// ...
1740/// vector.transfer_write %..., %allocOrView[...]
1741/// memref.copy(%subView, %out)
1742/// ```
1743/// into
1744/// ```
1745/// [unchanged] %alloc = ...
1746/// [unchanged] [optional] %view = memref.view %alloc ...
1747/// [unchanged] %subView = subview %allocOrView...
1748/// ...
1749/// vector.transfer_write %..., %out[...]
1750/// ```
1751/// Where there is no interleaved use between transfer_write and memref.copy.
1752/// This is a custom rewrite to forward partial writes to
1753/// vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  /// Redirect `xferOp` to write directly to the copy destination when the
  /// no-interleaved-use condition documented above holds. Defined
  /// out-of-line.
  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};
1761
1762/// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  /// A function to control pattern application and rewrite logic.
  ///
  /// The function will be given the slice op and should return:
  /// - std::nullopt: to fail the match and not apply the pattern;
  /// - true: to apply the pattern with zero slice guard;
  /// - false: to apply the pattern without zero slice guard.
  ///
  /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice
  /// guard.
  using ControlFn = std::function<std::optional<bool>(tensor::ExtractSliceOp)>;

  /// NOTE(review): a null `controlFn` (the default) presumably means the
  /// pattern applies unconditionally — confirm in the implementation.
  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
                                     ControlFn controlFn = nullptr,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  /// Swap extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
  /// Defined out-of-line.
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Optional user callback gating application; see `ControlFn`.
  ControlFn controlFn;
};
1787
1788//===----------------------------------------------------------------------===//
1789// Populate functions.
1790//===----------------------------------------------------------------------===//
1791
1792/// Canonicalization patterns relevant to apply after tiling patterns. These
1793/// are applied automatically by the tiling pass but need to be applied
1794/// manually when tiling is called programmatically.
1795RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
1796void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
1797
1798/// Linalg generalization patterns
1799
1800/// Populates `patterns` with patterns to convert spec-generated named ops to
1801/// linalg.generic ops.
1802void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
1803
1804/// Populates `patterns` with patterns to convert linalg.generic ops to named
1805/// ops where possible. A linalg.generic can represent wide range and complex
1806/// computations for which equivalent linalg named op may not exist e.g.
1807/// linalg.generic that takes a tensor and computes a polynomial such as:
1808/// p(x) = an*x^n + ... + a1x + a0
1809/// There is no equivalent named op to convert to. Many such cases exist.
1810void populateLinalgGenericOpsSpecializationPatterns(
1811 RewritePatternSet &patterns);
1812
1813/// Populates `patterns` with patterns that fold operations like
/// `linalg.transpose` into elementwise op map.
1815void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
1816
1817/// Linalg decompose convolutions patterns
1818
1819/// Populates patterns to decompose high-D convolution ops into low-D ones.
1820/// This is a step in progressive lowering for convolution ops, afterwards we
1821/// can vectorize the low-D convolution ops.
1822void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1823 PatternBenefit benefit = 1);
1824
1825/// Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
1826/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
1827/// outer dims to be unit.
1828void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
1829
1830/// Populates patterns to decompose tensor.pad into e.g.
1831/// tensor.empty, linalg.fill, tensor.insert_slice.
1832void populateDecomposePadPatterns(RewritePatternSet &patterns);
1833
1834/// Populates patterns to transform linalg.conv_2d_xxx operations into
1835/// linalg.generic (for img2col packing) and linalg.matmul.
1836/// \see rewriteInIm2Col for more details.
1837void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
1838
1839/// Populates `patterns` with patterns that vectorize tensor.pad.
1840/// These patterns are meant to apply in a complementary fashion. Benefits
1841/// are used to encode a certain ordering of pattern application. To avoid
1842/// scattering magic constants throughout the code base, the patterns must be
1843/// added with this function. `baseBenefit` can be used to offset the benefit
1844/// of all tensor::PadOp vectorization patterns by a certain value.
1845void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
1846 PatternBenefit baseBenefit = 1);
1847
1848/// Populate patterns for splitting a `LinalgOp` with multiple statements within
1849/// its payload into multiple `GenericOp` that have a single statement.
1850/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
1851/// and results from the generated decomposed ops. This is default `true` since
1852/// the core decomposition patterns relies on these clean up patterns. It is set
1853/// to false only for testing purposes.
1854void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
1855 bool removeDeadArgsAndResults = true);
1856
1857/// Populate patterns that convert non-destination-style ops to destination
1858/// style ops.
1859void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);
1860
1861/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assumes high-D convolution ops
1863/// were decomposed previously.
1864void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
1865 PatternBenefit benefit = 1);
1866
1867/// Populate patterns that convert `ElementwiseMappable` ops to linalg
1868/// parallel loops.
1869void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
1870
1871/// Populate patterns that are only useful in the context of sparse tensors.
1872void populateSparseTensorRewriting(RewritePatternSet &patterns);
1873
1874/// Function type which is used to control when to stop fusion. It is expected
1875/// that OpOperand is not modified in the callback. The OpOperand is not marked
1876/// as const to allow callers to use non-const methods.
1877using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
1878
1879/// Patterns for fusing linalg operation on tensors.
1880
1881/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations
1882/// when both operations are fusable elementwise operations.
1883void populateElementwiseOpsFusionPatterns(
1884 RewritePatternSet &patterns,
1885 const ControlFusionFn &controlElementwiseOpFusion);
1886
1887/// Function type which is used to control propagation of linalg.pack/unpack
1888/// ops.
1889using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
1890
1891/// Patterns to bubble up or down data layout ops across other operations.
1892void populateDataLayoutPropagationPatterns(
1893 RewritePatternSet &patterns,
1894 const ControlPropagationFn &controlPackUnPackPropagation);
1895
1896/// Pattern to remove dead operands and results of `linalg.generic` operations.
1897/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
1898void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
1899
1900/// Patterns to promote inputs to outputs and remove unused inputs of
1901/// `linalg.generic` ops.
1902void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
1903
1904/// Function type to control generic op dimension collapsing. It is expected
1905/// to return an array of `ReassociationIndices` representing dimensions that
1906/// should be merged.
1907using GetCollapsableDimensionsFn =
1908 std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
1909
1910/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
1911/// tensor operands when needed and expand back the result tensors.
1912void populateCollapseDimensions(
1913 RewritePatternSet &patterns,
1914 const GetCollapsableDimensionsFn &controlCollapseDimensions);
1915
1916/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
1917/// producer (consumer) generic operation by expanding the dimensionality of the
1918/// loop in the generic op.
1919void populateFoldReshapeOpsByExpansionPatterns(
1920 RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);
1921
1922/// Patterns to fold an expanding tensor.expand_shape operation with its
1923/// producer generic operation by collapsing the dimensions of the generic op.
1924void populateFoldReshapeOpsByCollapsingPatterns(
1925 RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);
1926
1927/// Patterns to constant fold Linalg operations.
1928void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
1929 const ControlFusionFn &controlFn);
1930
1931/// Pattern to replace `linalg.add` when destination passing on a contraction op
1932/// suffices for achieving the sum.
1933void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
1934
1935/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
1936/// if the producer is a `linalg` operation with all parallel iterator types.
1937void populateFuseTensorPadWithProducerLinalgOpPatterns(
1938 RewritePatternSet &patterns);
1939
1940/// Patterns to convert from one named op to another. These can be seen as
1941/// canonicalizations of named ops into another named op.
1942void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
1943
1944/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
1945/// tensors via reassociative reshape ops.
1946void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
1947 ControlDropUnitDims &options);
1948
1949/// A pattern that converts init operands to input operands.
1950void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
1951
1952/// Patterns that are used to inline constant operands into linalg generic ops.
1953void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
1954
1955/// Patterns that are used to bubble up extract slice op above linalg op.
1956void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
1957
/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
1959/// linalg.fill(%cst, tensor.extract_slice(%init)).
1960void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
1961
1962/// Add patterns to make explicit broadcasts and transforms in the
1963/// input operands of a genericOp.
1964void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);
1965
1966/// Patterns to apply `splitReduction` below.
1967void populateSplitReductionPattern(
1968 RewritePatternSet &patterns,
1969 const ControlSplitReductionFn &controlSplitReductionFn,
1970 bool useAlloc = false);
1971
1972/// Patterns to convert Linalg matmul ops to transposed variants.
1973void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
1974 bool transposeLHS = true);
1975
1976/// Patterns to block pack Linalg matmul ops.
1977void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
1978 const ControlBlockPackMatmulFn &controlFn);
1979
1980/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
1981void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
1982 WinogradConv2DFmr fmr);
1983
1984/// Patterns to decompose Winograd operators.
1985void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
1986
1987/// Adds patterns that reduce the rank of named contraction ops that have
1988/// unit dimensions in the operand(s) by converting to a sequence of
1989/// `collapse_shape`,
1990/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For
1991/// example a `linalg.batch_matmul` with unit batch size will convert to
/// `linalg.matmul` and a `linalg.matvec` with unit spatial dim in lhs will
1993/// convert to a `linalg.dot`.
1994void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
1995
1996/// Function type which is used to control folding operations like `tensor.pad`
1997/// and `tensor.extract_slice` into linalg.pack/unpack ops.
1998using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `linalg.pack` and `linalg.unpack` operations
/// respectively.
2002void populateFoldIntoPackAndUnpackPatterns(
2003 RewritePatternSet &patterns,
2004 const ControlFoldIntoPackUnpackFn &controlFn = nullptr);
2005
2006/// Populates `patterns` with patterns that fold operations like `linalg.pack`
2007/// and `linalg.unpack` into `tensor.empty`.
2008void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns);
2009
/// Populates `patterns` with patterns that simplify `linalg.pack` and
/// `linalg.unpack` operations.
2012void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
2013
2014} // namespace linalg
2015} // namespace mlir
2016
2017#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
2018

source code of mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h