1 | //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
10 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
11 | #include "mlir/IR/PatternMatch.h" |
12 | #include "llvm/Support/Debug.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::tensor; |
16 | |
17 | namespace { |
18 | /// Fold expand_shape(extract_slice) ops that cancel itself out. |
19 | struct FoldExpandOfRankReducingExtract |
20 | : public OpRewritePattern<ExpandShapeOp> { |
21 | using OpRewritePattern<ExpandShapeOp>::OpRewritePattern; |
22 | |
23 | LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp, |
24 | PatternRewriter &rewriter) const override { |
25 | RankedTensorType resultType = expandShapeOp.getResultType(); |
26 | auto = |
27 | expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>(); |
28 | if (!extractSliceOp) |
29 | return failure(); |
30 | RankedTensorType srcType = extractSliceOp.getSourceType(); |
31 | |
32 | // Only cases where the ExpandShapeOp can be folded away entirely are |
33 | // supported. Moreover, only simple cases where the resulting ExtractSliceOp |
34 | // has no rank-reduction anymore are supported at the moment. |
35 | RankedTensorType = ExtractSliceOp::inferResultType( |
36 | srcType, extractSliceOp.getStaticOffsets(), |
37 | extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); |
38 | if (nonReducingExtractType != resultType) |
39 | return failure(); |
40 | |
41 | SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); |
42 | SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); |
43 | SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); |
44 | rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
45 | expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes, |
46 | mixedStrides); |
47 | return success(); |
48 | } |
49 | }; |
50 | |
51 | /// Fold insert_slice(collapse_shape) ops that cancel itself out. |
52 | template <typename OpTy> |
53 | struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> { |
54 | using OpRewritePattern<OpTy>::OpRewritePattern; |
55 | |
56 | LogicalResult matchAndRewrite(OpTy insertSliceOp, |
57 | PatternRewriter &rewriter) const override { |
58 | auto collapseShapeOp = |
59 | insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>(); |
60 | if (!collapseShapeOp) |
61 | return failure(); |
62 | RankedTensorType srcType = collapseShapeOp.getSrcType(); |
63 | |
64 | // Only cases where the CollapseShapeOp can be folded away entirely are |
65 | // supported. Moreover, only simple cases where the resulting InsertSliceOp |
66 | // has no rank-reduction anymore are supported at the moment. |
67 | RankedTensorType nonReducingInsertType = |
68 | RankedTensorType::get(insertSliceOp.getStaticSizes(), |
69 | insertSliceOp.getDestType().getElementType()); |
70 | if (nonReducingInsertType != srcType) |
71 | return failure(); |
72 | |
73 | SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); |
74 | SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); |
75 | SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); |
76 | rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(), |
77 | insertSliceOp.getDest(), mixedOffsets, |
78 | mixedSizes, mixedStrides); |
79 | return success(); |
80 | } |
81 | }; |
82 | } // namespace |
83 | |
84 | void mlir::tensor::populateReassociativeReshapeFoldingPatterns( |
85 | RewritePatternSet &patterns) { |
86 | patterns.add<FoldExpandOfRankReducingExtract, |
87 | FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>, |
88 | FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>( |
89 | patterns.getContext()); |
90 | } |
91 | |