1//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/IR/Linalg.h"
10#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/Dialect/Utils/IndexingUtils.h"
13#include "mlir/IR/PatternMatch.h"
14
15namespace mlir {
16namespace linalg {
17namespace {
18
19/// Returns the number of shape sizes that is either dynamic or greater than 1.
20static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
21 return llvm::count_if(
22 Range&: shape, P: [](int64_t v) { return ShapedType::isDynamic(dValue: v) || v > 1; });
23}
24
25/// Returns success() if there is only 1 dimension size in non-packed domain
26/// being greater than 1 and packing only happens on the dimension.
27/// Note: this method should only be used by pack/unpack to reshape conversion.
28/// It assumes that non-unit inner tile size must be used by the non-unit
29/// dimension.
30static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
31 ArrayRef<int64_t> srcShape,
32 ArrayRef<int64_t> innerPackTileSize) {
33 if (getNumGtOneDims(shape: srcShape) > 1) {
34 return rewriter.notifyMatchFailure(
35 arg&: op, msg: "expects non-packed domain to have at most one non-unit dims");
36 }
37 // Non-unit inner tile size must be used by the non-unit dimension. If not, it
38 // will faill on getting reassociation maps.
39 if (getNumGtOneDims(shape: innerPackTileSize) > 1) {
40 return rewriter.notifyMatchFailure(
41 arg&: op, msg: "expects at most one non-unit inner tiles");
42 }
43 return success();
44}
45
// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  // A named linalg.transpose carries its permutation directly.
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(Val: linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  // Otherwise try to recognize a transpose written as a generic op:
  // all loops must be parallel (element-wise semantics) ...
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  // ... with exactly one input and one init operand ...
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  // ... whose input and output indexing maps are both permutations and
  // differ from each other (identical maps would be a plain copy) ...
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  // ... and whose body contains a single operation (the terminator), i.e. no
  // computation happens on the elements.
  if (!llvm::hasSingleElement(C&: linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(C: outMap.getResults(),
                             F: [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(input: expr);
                             });
}
73
74/// Packing one-dimensional tensor can be expressed as an expand shape op.
75struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
76 using OpRewritePattern<PackOp>::OpRewritePattern;
77
78 FailureOr<Value>
79 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
80 Type newOperandType,
81 ArrayRef<ReassociationIndices> reassociation) const {
82 if (operand.getType() == newOperandType)
83 return operand;
84 return rewriter
85 .create<tensor::ExpandShapeOp>(location: loc, args&: newOperandType, args&: operand,
86 args&: reassociation)
87 .getResult();
88 }
89
90 /// Returns success() if it is only packing on the innermost dimension.
91 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
92 PackOp packOp) const {
93 auto outerDimsPerm = packOp.getOuterDimsPerm();
94 if (!outerDimsPerm.empty() && !isIdentityPermutation(permutation: outerDimsPerm)) {
95 return rewriter.notifyMatchFailure(
96 arg&: packOp,
97 msg: "expects outer_dims_perm is empty or an identity permutation");
98 }
99
100 int64_t srcRank = packOp.getSourceRank();
101 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
102 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
103 return rewriter.notifyMatchFailure(
104 arg&: packOp, msg: "expects packing at the innermost dimension");
105 }
106 return success();
107 }
108
109 LogicalResult matchAndRewrite(PackOp packOp,
110 PatternRewriter &rewriter) const override {
111 if (packOp.getPaddingValue())
112 return rewriter.notifyMatchFailure(arg&: packOp, msg: "expects no padding value");
113
114 RankedTensorType sourceType = packOp.getSourceType();
115 if (failed(Result: isPackOnInnerMostDim(rewriter, packOp)) &&
116 failed(Result: isPackOn1D(rewriter, op: packOp, srcShape: sourceType.getShape(),
117 innerPackTileSize: packOp.getStaticTiles())) &&
118 !packOp.isLikePad()) {
119 return failure();
120 }
121
122 RankedTensorType destType = packOp.getDestType();
123 auto reassociation =
124 getReassociationIndicesForReshape(sourceType, targetType: destType);
125 if (!reassociation)
126 return failure();
127 FailureOr<Value> expanded =
128 insertExpand(rewriter, loc: packOp.getLoc(), operand: packOp.getSource(), newOperandType: destType,
129 reassociation: *reassociation);
130 if (failed(Result: expanded)) {
131 return rewriter.notifyMatchFailure(
132 arg&: packOp, msg: "unable to expand source of tensor.pack");
133 }
134 rewriter.replaceOp(op: packOp, newValues: *expanded);
135 return success();
136 }
137};
138
139struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
140 using OpRewritePattern<UnPackOp>::OpRewritePattern;
141
142 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
143 Type newOperandType, ArrayAttr reassociation) const {
144 if (operand.getType() == newOperandType)
145 return operand;
146 return rewriter.create<tensor::CollapseShapeOp>(location: loc, args&: newOperandType,
147 args&: operand, args&: reassociation);
148 }
149
150 /// Returns success() if it is unpacking on the innermost dimension.
151 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
152 UnPackOp unpackOp) const {
153 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
154 if (!outerDimsPerm.empty() && !isIdentityPermutation(permutation: outerDimsPerm)) {
155 return rewriter.notifyMatchFailure(
156 arg&: unpackOp,
157 msg: "expects outer_dims_perm is empty or an identity permutation");
158 }
159
160 RankedTensorType sourceType = unpackOp.getSourceType();
161 RankedTensorType destType = unpackOp.getDestType();
162 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
163 return rewriter.notifyMatchFailure(arg&: unpackOp, msg: "expects static shapes");
164
165 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
166 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
167 return rewriter.notifyMatchFailure(
168 arg&: unpackOp, msg: "expects unpacking on the innermost dimension");
169 }
170
171 return success();
172 }
173
174 LogicalResult matchAndRewrite(UnPackOp unpackOp,
175 PatternRewriter &rewriter) const override {
176 RankedTensorType destType = unpackOp.getDestType();
177 if (failed(Result: isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
178 failed(Result: isPackOn1D(rewriter, op: unpackOp, srcShape: destType.getShape(),
179 innerPackTileSize: unpackOp.getStaticTiles())) &&
180 !unpackOp.isLikeUnPad()) {
181 return failure();
182 }
183
184 RankedTensorType sourceType = unpackOp.getSourceType();
185 auto reassociation =
186 getReassociationIndicesForReshape(sourceType, targetType: destType);
187 if (!reassociation)
188 return failure();
189 Value collapsed = insertCollapse(
190 rewriter, loc: unpackOp.getLoc(), operand: unpackOp.getSource(), newOperandType: destType,
191 reassociation: getReassociationIndicesAttribute(b&: rewriter, reassociation: *reassociation));
192 rewriter.replaceOp(op: unpackOp, newValues: collapsed);
193 return success();
194 }
195};
196
197/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198/// the pad op has zero low paddings, or if `pack` has no padding values.
199struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
200public:
201 FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
202 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
203
204 LogicalResult matchAndRewrite(PackOp packOp,
205 PatternRewriter &rewriter) const override {
206 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
207
208 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
209 return failure();
210
211 // User controlled folding function.
212 if (controlFn && !controlFn(&packOp.getSourceMutable()))
213 return failure();
214
215 Value constantPaddingValue = padOp.getConstantPaddingValue();
216 if (!constantPaddingValue)
217 return failure();
218
219 if (auto paddingValue = packOp.getPaddingValue())
220 if (!isEqualConstantIntOrValue(ofr1: paddingValue, ofr2: constantPaddingValue))
221 return failure();
222
223 rewriter.replaceOpWithNewOp<PackOp>(
224 op: packOp, args: padOp.getSource(), args: packOp.getDest(), args: packOp.getInnerDimsPos(),
225 args: packOp.getMixedTiles(), args&: constantPaddingValue,
226 args: packOp.getOuterDimsPerm());
227 return success();
228 }
229
230private:
231 ControlFoldIntoPackUnpackFn controlFn;
232};
233
234/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
235/// has extract_slice semantics.
236struct FoldUnpackWithExtractSliceOp
237 : public OpRewritePattern<tensor::ExtractSliceOp> {
238public:
239 FoldUnpackWithExtractSliceOp(MLIRContext *context,
240 ControlFoldIntoPackUnpackFn controlFn)
241 : OpRewritePattern<tensor::ExtractSliceOp>(context),
242 controlFn(std::move(controlFn)) {}
243
244 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
245 PatternRewriter &rewriter) const override {
246 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
247 if (!unpackOp)
248 return failure();
249
250 // User controlled folding function.
251 if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
252 return failure();
253
254 if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
255 return rewriter.notifyMatchFailure(
256 arg&: sliceOp, msg: "rank-reduced folding is not supported");
257 }
258
259 // Check all offsets are zeros, and all strides are ones.
260 if (!areAllConstantIntValue(ofrs: sliceOp.getMixedOffsets(), value: 0) ||
261 !areAllConstantIntValue(ofrs: sliceOp.getMixedStrides(), value: 1)) {
262 return rewriter.notifyMatchFailure(
263 arg&: sliceOp, msg: "expects offsets to be 0s and strides to be 1s");
264 }
265
266 // Create a new empty output tensor.
267 Type elementType = unpackOp.getDestType().getElementType();
268 Value output = rewriter.create<tensor::EmptyOp>(
269 location: sliceOp.getLoc(), args: sliceOp.getMixedSizes(), args&: elementType);
270 rewriter.replaceOpWithNewOp<UnPackOp>(
271 op: sliceOp, args: unpackOp.getSource(), args&: output, args: unpackOp.getInnerDimsPos(),
272 args: unpackOp.getMixedTiles(), args: unpackOp.getOuterDimsPerm());
273 return success();
274 }
275
276private:
277 ControlFoldIntoPackUnpackFn controlFn;
278};
279
280// Applies 'permutation' on 'inVec' and stores the result in resVec.
281// 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
282// `rank` sets the boundary for permutation i.e., the permutation dim can't be
283// greater than the rank specified. If it's so then return false.
284// For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
285// permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
286// not allowed since `3` exceeds the value of the rank in the given range.
287static bool checkAndPermute(ArrayRef<int64_t> permutation,
288 ArrayRef<int64_t> inVec,
289 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
290
291 for (unsigned int i = 0; i < rank; ++i) {
292 int64_t remappedPosition = permutation[i];
293 if (remappedPosition >= rank)
294 return false;
295 if (!inVec.empty())
296 remappedPosition = inVec[remappedPosition];
297 resVec.push_back(Elt: remappedPosition);
298 }
299
300 return true;
301}
302
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {

public:
  FoldProducerPackWithConsumerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // The linalg op's first operand must be produced by a tensor.pack.
    auto packOp = linalgOp->getOperand(idx: 0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&linalgOp->getOpOperand(idx: 0)))
      return failure();

    // The linalg op itself must be a transpose (named or generic form).
    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(Result: maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    // The first `srcRank` entries of the transpose permutation act on the
    // outer (untiled) dims: compose them into outer_dims_perm. This fails if
    // any of those entries points into the inner tile dims.
    if (!checkAndPermute(permutation: transposePerm, inVec: outerDimsPerm, resVec&: newOuterDimsPermVec,
                         rank: srcRank))
      return rewriter.notifyMatchFailure(
          arg&: linalgOp,
          msg: "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process transpose operation for tiled inner dimensions: entries past
    // `srcRank` permute the inner tiles together with their source positions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(Elt: mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]);
    }

    // Create a destination matching the transposed layout and replace the
    // transpose with a single pack performing the composed permutation.
    Value output = packOp.createDestinationTensor(
        b&: rewriter, loc: linalgOp.getLoc(), source: packOp.getSource(), innerTileSizes: newMixedInnerTilesVec,
        innerDimsPos: newInnerDimsPosVec, outerDimsPerm: newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        op: linalgOp, args: packOp.getSource(), args&: output, args&: newInnerDimsPosVec,
        args&: newMixedInnerTilesVec, args: packOp.getPaddingValue(), args&: newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};
367
368/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
369/// semantics.
370struct FoldConsumerPackWithProducerLinalgTransposeOp
371 : public OpRewritePattern<PackOp> {
372
373public:
374 FoldConsumerPackWithProducerLinalgTransposeOp(
375 MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
376 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
377
378 LogicalResult matchAndRewrite(PackOp packOp,
379 PatternRewriter &rewriter) const override {
380 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
381 if (!linalgOp)
382 return failure();
383
384 // User controlled folding function.
385 if (controlFn && !controlFn(&packOp.getSourceMutable()))
386 return failure();
387
388 FailureOr<SmallVector<int64_t>> maybePerm =
389 getTransposeOpPermutation(linalgOp);
390 if (failed(Result: maybePerm))
391 return failure();
392
393 auto transposePermutation = maybePerm.value();
394 auto outerDimsPerm = packOp.getOuterDimsPerm();
395 auto innerDimsPos = packOp.getInnerDimsPos();
396 SmallVector<int64_t> newInnerDimsPosVec;
397 SmallVector<int64_t> newOuterDimsPermVec =
398 llvm::to_vector(Range&: transposePermutation);
399
400 if (!outerDimsPerm.empty())
401 applyPermutationToVector(inVec&: newOuterDimsPermVec, permutation: outerDimsPerm);
402
403 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
404 // permutation rank won't necessarily be equal in all cases.
405 for (auto dim : innerDimsPos)
406 newInnerDimsPosVec.push_back(Elt: transposePermutation[dim]);
407
408 Value output = packOp.createDestinationTensor(
409 b&: rewriter, loc: packOp.getLoc(), source: linalgOp->getOperand(idx: 0),
410 innerTileSizes: packOp.getMixedTiles(), innerDimsPos: newInnerDimsPosVec, outerDimsPerm: newOuterDimsPermVec);
411
412 rewriter.replaceOpWithNewOp<PackOp>(
413 op: packOp, args: linalgOp->getOperand(idx: 0), args&: output, args&: newInnerDimsPosVec,
414 args: packOp.getMixedTiles(), args: packOp.getPaddingValue(), args&: newOuterDimsPermVec);
415
416 return success();
417 }
418
419private:
420 ControlFoldIntoPackUnpackFn controlFn;
421};
422
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {

public:
  FoldProducerUnPackWithConsumerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // The linalg op's first operand must be produced by a tensor.unpack.
    auto unPackOp = linalgOp->getOperand(idx: 0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&linalgOp->getOpOperand(idx: 0)))
      return failure();

    // The linalg op itself must be a transpose (named or generic form).
    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(Result: maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    // A consumer transpose folds into the unpack as the inverse of its
    // permutation.
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(permutation: maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
    // permutation rank won't necessarily be equal in all cases.
    // Note: the inner dims are remapped through the inverse permutation
    // *before* the existing outer_dims_perm is composed in below.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(Elt: newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(inVec&: newOuterDimsPermVec, permutation: outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        op: linalgOp, args: unPackOp.getSource(), args: linalgOp.getDpsInits()[0],
        args&: newInnerDimsPosVec, args: unPackOp.getMixedTiles(), args&: newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};
475
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

public:
  FoldConsumerUnPackWithProducerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    // The unpack source must be produced by a transpose-like linalg op.
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(Result: maybePerm))
      return failure();

    // Reify the unpack's result shape so a fresh destination of the right
    // (possibly dynamic) size can be created below.
    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(Result: reifyResultShapes(b&: rewriter, op: unPackOp, reifiedReturnShapes&: unpackOpResultDims))) {
      return failure();
    }

    // A producer transpose folds into the unpack as the inverse of its
    // permutation.
    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(permutation: maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    // Rank of the untiled (outer) domain: source rank minus the tiled dims.
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    // The first `destRank` entries of the inverse permutation act on the
    // outer dims: compose them into outer_dims_perm. This fails if any of
    // those entries points into the inner tile dims.
    if (!checkAndPermute(permutation: inverseTransposePerm, inVec: outerDimsPerm,
                         resVec&: newOuterDimsPermVec, rank: destRank))
      return rewriter.notifyMatchFailure(
          arg&: unPackOp,
          msg: "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process transpose operation for tiled inner dimensions: entries past
    // `destRank` permute the inner tiles together with their positions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(Elt: mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]);
    }

    // Create a fresh destination of the reified size and rebuild the unpack
    // on the transpose's input with the composed permutations.
    auto elemType =
        cast<ShapedType>(Val: unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        location: unPackOp->getLoc(), args&: unpackOpResultDims[0], args&: elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        op: unPackOp, args: linalgOp->getOperand(idx: 0), args&: output, args&: newInnerDimsPosVec,
        args&: newMixedInnerTilesVec, args&: newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};
545
546/// tensor.empty does not define any tensor contents, so an unpadded pack
547/// can be folded away.
548struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
549 using OpRewritePattern<PackOp>::OpRewritePattern;
550
551 LogicalResult matchAndRewrite(PackOp packOp,
552 PatternRewriter &rewriter) const override {
553 // Check for tensor.empty source.
554 auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
555 if (!emptyOp)
556 return failure();
557
558 // Check for padding.
559 // Packing with padding cannot be simply removed.
560 if (packOp.getPaddingValue())
561 return rewriter.notifyMatchFailure(arg&: packOp, msg: "expects no padding value");
562
563 // Replace the pack directly with its destination.
564 rewriter.replaceOp(op: packOp, newValues: packOp.getDest());
565
566 return success();
567 }
568};
569
570/// tensor.empty does not define any tensor contents, so an unpack
571/// can be folded away.
572struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
573 using OpRewritePattern<UnPackOp>::OpRewritePattern;
574
575 LogicalResult matchAndRewrite(UnPackOp unPackOp,
576 PatternRewriter &rewriter) const override {
577 // Check for tensor.empty source.
578 auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
579 if (!emptyOp)
580 return failure();
581
582 // Replace the unpack directly with its destination.
583 rewriter.replaceOp(op: unPackOp, newValues: unPackOp.getDest());
584
585 return success();
586 }
587};
588
589} // namespace
590
591void populateFoldIntoPackAndUnpackPatterns(
592 RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
593 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
594 FoldProducerPackWithConsumerLinalgTransposeOp,
595 FoldConsumerPackWithProducerLinalgTransposeOp,
596 FoldConsumerUnPackWithProducerLinalgTransposeOp,
597 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
598 arg: patterns.getContext(), args: controlFn);
599}
600
601void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
602 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
603 arg: patterns.getContext());
604}
605
606void populateFoldPackUnpackIntoTensorEmptyPatterns(
607 RewritePatternSet &patterns) {
608 patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
609 arg: patterns.getContext());
610}
611
612} // namespace linalg
613} // namespace mlir
614

// Source: mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp