//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
10 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
11 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
12 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | |
15 | namespace mlir { |
16 | namespace linalg { |
17 | namespace { |
18 | |
19 | /// Returns the number of shape sizes that is either dynamic or greater than 1. |
20 | static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) { |
21 | return llvm::count_if( |
22 | Range&: shape, P: [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; }); |
23 | } |
24 | |
25 | /// Returns success() if there is only 1 dimension size in non-packed domain |
26 | /// being greater than 1 and packing only happens on the dimension. |
27 | /// Note: this method should only be used by pack/unpack to reshape conversion. |
28 | /// It assumes that non-unit inner tile size must be used by the non-unit |
29 | /// dimension. |
30 | static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op, |
31 | ArrayRef<int64_t> srcShape, |
32 | ArrayRef<int64_t> innerPackTileSize) { |
33 | if (getNumGtOneDims(shape: srcShape) > 1) { |
34 | return rewriter.notifyMatchFailure( |
35 | arg&: op, msg: "expects non-packed domain to have at most one non-unit dims" ); |
36 | } |
37 | // Non-unit inner tile size must be used by the non-unit dimension. If not, it |
38 | // will faill on getting reassociation maps. |
39 | if (getNumGtOneDims(shape: innerPackTileSize) > 1) { |
40 | return rewriter.notifyMatchFailure( |
41 | arg&: op, msg: "expects at most one non-unit inner tiles" ); |
42 | } |
43 | return success(); |
44 | } |
45 | |
46 | // If the `linalgOp` represents a transpose, return the permutation vector for |
47 | // the transpose. Otherwise, return failure. |
48 | static FailureOr<SmallVector<int64_t>> |
49 | getTransposeOpPermutation(linalg::LinalgOp linalgOp) { |
50 | if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation())) |
51 | return SmallVector<int64_t>(transposeOp.getPermutation()); |
52 | if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) |
53 | return failure(); |
54 | |
55 | if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) |
56 | return failure(); |
57 | auto mapRange = linalgOp.getIndexingMapsArray(); |
58 | if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() || |
59 | mapRange.front() == mapRange.back()) { |
60 | return failure(); |
61 | } |
62 | if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations())) |
63 | return failure(); |
64 | AffineMap outMap = mapRange.back(); |
65 | AffineMap inMap = mapRange.front(); |
66 | // To get the permutation, look at each output index and find which |
67 | // dimension in the input we're reading from for that index. |
68 | return llvm::map_to_vector(C: outMap.getResults(), |
69 | F: [&](AffineExpr expr) -> int64_t { |
70 | return *inMap.getResultPosition(input: expr); |
71 | }); |
72 | } |
73 | |
74 | /// Packing one-dimensional tensor can be expressed as an expand shape op. |
75 | struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> { |
76 | using OpRewritePattern<PackOp>::OpRewritePattern; |
77 | |
78 | FailureOr<Value> |
79 | insertExpand(RewriterBase &rewriter, Location loc, Value operand, |
80 | Type newOperandType, |
81 | ArrayRef<ReassociationIndices> reassociation) const { |
82 | if (operand.getType() == newOperandType) |
83 | return operand; |
84 | return rewriter |
85 | .create<tensor::ExpandShapeOp>(loc, newOperandType, operand, |
86 | reassociation) |
87 | .getResult(); |
88 | } |
89 | |
90 | /// Returns success() if it is only packing on the innermost dimension. |
91 | LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter, |
92 | PackOp packOp) const { |
93 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
94 | if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { |
95 | return rewriter.notifyMatchFailure( |
96 | packOp, |
97 | "expects outer_dims_perm is empty or an identity permutation" ); |
98 | } |
99 | |
100 | int64_t srcRank = packOp.getSourceRank(); |
101 | ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos(); |
102 | if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) { |
103 | return rewriter.notifyMatchFailure( |
104 | packOp, "expects packing at the innermost dimension" ); |
105 | } |
106 | return success(); |
107 | } |
108 | |
109 | LogicalResult matchAndRewrite(PackOp packOp, |
110 | PatternRewriter &rewriter) const override { |
111 | if (packOp.getPaddingValue()) |
112 | return rewriter.notifyMatchFailure(packOp, "expects no padding value" ); |
113 | |
114 | RankedTensorType sourceType = packOp.getSourceType(); |
115 | if (failed(isPackOnInnerMostDim(rewriter, packOp)) && |
116 | failed(isPackOn1D(rewriter, packOp, sourceType.getShape(), |
117 | packOp.getStaticTiles())) && |
118 | !packOp.isLikePad()) { |
119 | return failure(); |
120 | } |
121 | |
122 | RankedTensorType destType = packOp.getDestType(); |
123 | auto reassociation = |
124 | getReassociationIndicesForReshape(sourceType, destType); |
125 | if (!reassociation) |
126 | return failure(); |
127 | FailureOr<Value> expanded = |
128 | insertExpand(rewriter, loc: packOp.getLoc(), operand: packOp.getSource(), newOperandType: destType, |
129 | reassociation: *reassociation); |
130 | if (failed(Result: expanded)) { |
131 | return rewriter.notifyMatchFailure( |
132 | packOp, "unable to expand source of tensor.pack" ); |
133 | } |
134 | rewriter.replaceOp(packOp, *expanded); |
135 | return success(); |
136 | } |
137 | }; |
138 | |
139 | struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> { |
140 | using OpRewritePattern<UnPackOp>::OpRewritePattern; |
141 | |
142 | Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand, |
143 | Type newOperandType, ArrayAttr reassociation) const { |
144 | if (operand.getType() == newOperandType) |
145 | return operand; |
146 | return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType, |
147 | operand, reassociation); |
148 | } |
149 | |
150 | /// Returns success() if it is unpacking on the innermost dimension. |
151 | LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter, |
152 | UnPackOp unpackOp) const { |
153 | auto outerDimsPerm = unpackOp.getOuterDimsPerm(); |
154 | if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { |
155 | return rewriter.notifyMatchFailure( |
156 | unpackOp, |
157 | "expects outer_dims_perm is empty or an identity permutation" ); |
158 | } |
159 | |
160 | RankedTensorType sourceType = unpackOp.getSourceType(); |
161 | RankedTensorType destType = unpackOp.getDestType(); |
162 | if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) |
163 | return rewriter.notifyMatchFailure(unpackOp, "expects static shapes" ); |
164 | |
165 | ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos(); |
166 | if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) { |
167 | return rewriter.notifyMatchFailure( |
168 | unpackOp, "expects unpacking on the innermost dimension" ); |
169 | } |
170 | |
171 | return success(); |
172 | } |
173 | |
174 | LogicalResult matchAndRewrite(UnPackOp unpackOp, |
175 | PatternRewriter &rewriter) const override { |
176 | RankedTensorType destType = unpackOp.getDestType(); |
177 | if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) && |
178 | failed(isPackOn1D(rewriter, unpackOp, destType.getShape(), |
179 | unpackOp.getStaticTiles())) && |
180 | !unpackOp.isLikeUnPad()) { |
181 | return failure(); |
182 | } |
183 | |
184 | RankedTensorType sourceType = unpackOp.getSourceType(); |
185 | auto reassociation = |
186 | getReassociationIndicesForReshape(sourceType, destType); |
187 | if (!reassociation) |
188 | return failure(); |
189 | Value collapsed = insertCollapse( |
190 | rewriter, loc: unpackOp.getLoc(), operand: unpackOp.getSource(), newOperandType: destType, |
191 | reassociation: getReassociationIndicesAttribute(rewriter, *reassociation)); |
192 | rewriter.replaceOp(unpackOp, collapsed); |
193 | return success(); |
194 | } |
195 | }; |
196 | |
197 | /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and |
198 | /// the pad op has zero low paddings, or if `pack` has no padding values. |
199 | struct FoldPadWithPackOp : public OpRewritePattern<PackOp> { |
200 | using OpRewritePattern<PackOp>::OpRewritePattern; |
201 | |
202 | LogicalResult matchAndRewrite(PackOp packOp, |
203 | PatternRewriter &rewriter) const override { |
204 | auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>(); |
205 | |
206 | if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) |
207 | return failure(); |
208 | |
209 | Value constantPaddingValue = padOp.getConstantPaddingValue(); |
210 | if (!constantPaddingValue) |
211 | return failure(); |
212 | |
213 | if (auto paddingValue = packOp.getPaddingValue()) |
214 | if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue)) |
215 | return failure(); |
216 | |
217 | rewriter.replaceOpWithNewOp<PackOp>( |
218 | packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(), |
219 | packOp.getMixedTiles(), constantPaddingValue, |
220 | packOp.getOuterDimsPerm()); |
221 | return success(); |
222 | } |
223 | }; |
224 | |
225 | /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already |
226 | /// has extract_slice semantics. |
227 | struct |
228 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
229 | using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
230 | |
231 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
232 | PatternRewriter &rewriter) const override { |
233 | auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>(); |
234 | if (!unpackOp) |
235 | return failure(); |
236 | |
237 | if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { |
238 | return rewriter.notifyMatchFailure( |
239 | sliceOp, "rank-reduced folding is not supported" ); |
240 | } |
241 | |
242 | // Check all offsets are zeros, and all strides are ones. |
243 | if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || |
244 | !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { |
245 | return rewriter.notifyMatchFailure( |
246 | sliceOp, "expects offsets to be 0s and strides to be 1s" ); |
247 | } |
248 | |
249 | // Create a new empty output tensor. |
250 | Type elementType = unpackOp.getDestType().getElementType(); |
251 | Value output = rewriter.create<tensor::EmptyOp>( |
252 | sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); |
253 | rewriter.replaceOpWithNewOp<UnPackOp>( |
254 | sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), |
255 | unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); |
256 | return success(); |
257 | } |
258 | }; |
259 | |
260 | // Applies 'permutation' on 'inVec' and stores the result in resVec. |
261 | // 'inVec' may be empty, in that case it's one-to-one mapping with permutation. |
262 | // `rank` sets the boundary for permutation i.e., the permutation dim can't be |
263 | // greater than the rank specified. If it's so then return false. |
264 | // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in |
265 | // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is |
266 | // not allowed since `3` exceeds the value of the rank in the given range. |
267 | static bool checkAndPermute(ArrayRef<int64_t> permutation, |
268 | ArrayRef<int64_t> inVec, |
269 | SmallVectorImpl<int64_t> &resVec, int64_t rank) { |
270 | |
271 | for (unsigned int i = 0; i < rank; ++i) { |
272 | int64_t remappedPosition = permutation[i]; |
273 | if (remappedPosition >= rank) |
274 | return false; |
275 | if (!inVec.empty()) |
276 | remappedPosition = inVec[remappedPosition]; |
277 | resVec.push_back(Elt: remappedPosition); |
278 | } |
279 | |
280 | return true; |
281 | } |
282 | |
283 | /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose |
284 | /// semantics. |
285 | struct FoldProducerPackWithConsumerLinalgTransposeOp |
286 | : public OpInterfaceRewritePattern<linalg::LinalgOp> { |
287 | using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern; |
288 | |
289 | LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, |
290 | PatternRewriter &rewriter) const override { |
291 | auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>(); |
292 | |
293 | if (!packOp) |
294 | return failure(); |
295 | |
296 | FailureOr<SmallVector<int64_t>> maybePerm = |
297 | getTransposeOpPermutation(linalgOp); |
298 | if (failed(Result: maybePerm)) |
299 | return failure(); |
300 | |
301 | auto innerDimsPos = packOp.getInnerDimsPos(); |
302 | auto mixedInnerTiles = packOp.getMixedTiles(); |
303 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
304 | auto transposePerm = maybePerm.value(); |
305 | SmallVector<int64_t> newOuterDimsPermVec; |
306 | SmallVector<int64_t> newInnerDimsPosVec; |
307 | SmallVector<OpFoldResult> newMixedInnerTilesVec; |
308 | int64_t srcRank = packOp.getSourceRank(); |
309 | |
310 | if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec, |
311 | srcRank)) |
312 | return rewriter.notifyMatchFailure( |
313 | linalgOp, |
314 | "Cannot fold in tensor.pack if a tile dimension was transposed " |
315 | "with a non-tile dimension in linalg.transpose." ); |
316 | |
317 | // Process transpose operation for tiled inner dimensions |
318 | for (unsigned int i = srcRank; i < transposePerm.size(); ++i) { |
319 | int64_t remappedPosition = transposePerm[i] - srcRank; |
320 | newMixedInnerTilesVec.push_back(Elt: mixedInnerTiles[remappedPosition]); |
321 | newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]); |
322 | } |
323 | |
324 | Value output = packOp.createDestinationTensor( |
325 | rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec, |
326 | newInnerDimsPosVec, newOuterDimsPermVec); |
327 | |
328 | rewriter.replaceOpWithNewOp<PackOp>( |
329 | linalgOp, packOp.getSource(), output, newInnerDimsPosVec, |
330 | newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec); |
331 | |
332 | return success(); |
333 | } |
334 | }; |
335 | |
336 | /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose |
337 | /// semantics. |
338 | struct FoldConsumerPackWithProducerLinalgTransposeOp |
339 | : public OpRewritePattern<PackOp> { |
340 | using OpRewritePattern<PackOp>::OpRewritePattern; |
341 | |
342 | LogicalResult matchAndRewrite(PackOp packOp, |
343 | PatternRewriter &rewriter) const override { |
344 | auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>(); |
345 | if (!linalgOp) |
346 | return failure(); |
347 | |
348 | FailureOr<SmallVector<int64_t>> maybePerm = |
349 | getTransposeOpPermutation(linalgOp); |
350 | if (failed(Result: maybePerm)) |
351 | return failure(); |
352 | |
353 | auto transposePermutation = maybePerm.value(); |
354 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
355 | auto innerDimsPos = packOp.getInnerDimsPos(); |
356 | SmallVector<int64_t> newInnerDimsPosVec; |
357 | SmallVector<int64_t> newOuterDimsPermVec = |
358 | llvm::to_vector(transposePermutation); |
359 | |
360 | if (!outerDimsPerm.empty()) |
361 | applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); |
362 | |
363 | // Can't use applyPermutationToVector for newInnerDimsPosVec since input and |
364 | // permutation rank won't necessarily be equal in all cases. |
365 | for (auto dim : innerDimsPos) |
366 | newInnerDimsPosVec.push_back(transposePermutation[dim]); |
367 | |
368 | Value output = packOp.createDestinationTensor( |
369 | rewriter, packOp.getLoc(), linalgOp->getOperand(0), |
370 | packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec); |
371 | |
372 | rewriter.replaceOpWithNewOp<PackOp>( |
373 | packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec, |
374 | packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec); |
375 | |
376 | return success(); |
377 | } |
378 | }; |
379 | |
380 | /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has |
381 | /// transpose semantics. |
382 | struct FoldProducerUnPackWithConsumerLinalgTransposeOp |
383 | : public OpInterfaceRewritePattern<linalg::LinalgOp> { |
384 | using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern; |
385 | |
386 | LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, |
387 | PatternRewriter &rewriter) const override { |
388 | auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>(); |
389 | |
390 | if (!unPackOp) |
391 | return failure(); |
392 | |
393 | FailureOr<SmallVector<int64_t>> maybePerm = |
394 | getTransposeOpPermutation(linalgOp); |
395 | if (failed(Result: maybePerm)) |
396 | return failure(); |
397 | |
398 | auto outerDimsPerm = unPackOp.getOuterDimsPerm(); |
399 | auto innerDimsPos = unPackOp.getInnerDimsPos(); |
400 | SmallVector<int64_t> newInnerDimsPosVec; |
401 | SmallVector<int64_t> newOuterDimsPermVec = |
402 | invertPermutationVector(permutation: maybePerm.value()); |
403 | |
404 | // Can't use applyPermutationToVector for newInnerDimsPosVec since input and |
405 | // permutation rank won't necessarily be equal in all cases. |
406 | for (auto dim : innerDimsPos) |
407 | newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]); |
408 | |
409 | if (!outerDimsPerm.empty()) |
410 | applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm); |
411 | |
412 | // Reuse the destination of the transpose op. |
413 | rewriter.replaceOpWithNewOp<UnPackOp>( |
414 | linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0], |
415 | newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec); |
416 | |
417 | return success(); |
418 | } |
419 | }; |
420 | |
421 | /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has |
422 | /// transpose semantics. |
423 | struct FoldConsumerUnPackWithProducerLinalgTransposeOp |
424 | : public OpRewritePattern<UnPackOp> { |
425 | using OpRewritePattern<UnPackOp>::OpRewritePattern; |
426 | |
427 | LogicalResult matchAndRewrite(UnPackOp unPackOp, |
428 | PatternRewriter &rewriter) const override { |
429 | auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>(); |
430 | if (!linalgOp) |
431 | return failure(); |
432 | |
433 | FailureOr<SmallVector<int64_t>> maybePerm = |
434 | getTransposeOpPermutation(linalgOp); |
435 | if (failed(Result: maybePerm)) |
436 | return failure(); |
437 | |
438 | SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims; |
439 | if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) { |
440 | return failure(); |
441 | } |
442 | |
443 | SmallVector<int64_t> inverseTransposePerm = |
444 | invertPermutationVector(permutation: maybePerm.value()); |
445 | auto outerDimsPerm = unPackOp.getOuterDimsPerm(); |
446 | auto innerDimsPos = unPackOp.getInnerDimsPos(); |
447 | int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size(); |
448 | auto mixedInnerTilesVec = unPackOp.getMixedTiles(); |
449 | SmallVector<int64_t> newOuterDimsPermVec; |
450 | SmallVector<int64_t> newInnerDimsPosVec; |
451 | SmallVector<OpFoldResult> newMixedInnerTilesVec; |
452 | if (!checkAndPermute(inverseTransposePerm, outerDimsPerm, |
453 | newOuterDimsPermVec, destRank)) |
454 | return rewriter.notifyMatchFailure( |
455 | unPackOp, |
456 | "Cannot fold in tensor.unpack if a tile dimension was transposed " |
457 | "with a non-tile dimension in linalg.transpose." ); |
458 | |
459 | // Process transpose operation for tiled inner dimensions |
460 | for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) { |
461 | int64_t remappedPosition = inverseTransposePerm[i] - destRank; |
462 | newMixedInnerTilesVec.push_back(Elt: mixedInnerTilesVec[remappedPosition]); |
463 | newInnerDimsPosVec.push_back(Elt: innerDimsPos[remappedPosition]); |
464 | } |
465 | |
466 | auto elemType = |
467 | cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType(); |
468 | Value output = rewriter.create<tensor::EmptyOp>( |
469 | unPackOp->getLoc(), unpackOpResultDims[0], elemType); |
470 | |
471 | rewriter.replaceOpWithNewOp<UnPackOp>( |
472 | unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec, |
473 | newMixedInnerTilesVec, newOuterDimsPermVec); |
474 | |
475 | return success(); |
476 | } |
477 | }; |
478 | |
479 | /// tensor.empty does not define any tensor contents, so an unpadded pack |
480 | /// can be folded away. |
481 | struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> { |
482 | using OpRewritePattern<PackOp>::OpRewritePattern; |
483 | |
484 | LogicalResult matchAndRewrite(PackOp packOp, |
485 | PatternRewriter &rewriter) const override { |
486 | // Check for tensor.empty source. |
487 | auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>(); |
488 | if (!emptyOp) |
489 | return failure(); |
490 | |
491 | // Check for padding. |
492 | // Packing with padding cannot be simply removed. |
493 | if (packOp.getPaddingValue()) |
494 | return rewriter.notifyMatchFailure(packOp, "expects no padding value" ); |
495 | |
496 | // Replace the pack directly with its destination. |
497 | rewriter.replaceOp(packOp, packOp.getDest()); |
498 | |
499 | return success(); |
500 | } |
501 | }; |
502 | |
503 | /// tensor.empty does not define any tensor contents, so an unpack |
504 | /// can be folded away. |
505 | struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> { |
506 | using OpRewritePattern<UnPackOp>::OpRewritePattern; |
507 | |
508 | LogicalResult matchAndRewrite(UnPackOp unPackOp, |
509 | PatternRewriter &rewriter) const override { |
510 | // Check for tensor.empty source. |
511 | auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>(); |
512 | if (!emptyOp) |
513 | return failure(); |
514 | |
515 | // Replace the unpack directly with its destination. |
516 | rewriter.replaceOp(unPackOp, unPackOp.getDest()); |
517 | |
518 | return success(); |
519 | } |
520 | }; |
521 | |
522 | } // namespace |
523 | |
524 | void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { |
525 | patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp, |
526 | FoldProducerPackWithConsumerLinalgTransposeOp, |
527 | FoldConsumerPackWithProducerLinalgTransposeOp, |
528 | FoldConsumerUnPackWithProducerLinalgTransposeOp, |
529 | FoldProducerUnPackWithConsumerLinalgTransposeOp>( |
530 | arg: patterns.getContext()); |
531 | } |
532 | |
533 | void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) { |
534 | patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>( |
535 | arg: patterns.getContext()); |
536 | } |
537 | |
538 | void populateFoldPackUnpackIntoTensorEmptyPatterns( |
539 | RewritePatternSet &patterns) { |
540 | patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>( |
541 | arg: patterns.getContext()); |
542 | } |
543 | |
544 | } // namespace linalg |
545 | } // namespace mlir |
546 | |