//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
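///
/// Example (illustrative IR; the shapes are hypothetical and only meant to
/// show the supported case, where the expand_shape re-introduces exactly the
/// unit dims dropped by the rank-reducing extract_slice):
///
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 16, 1] [1, 1, 1]
///       : tensor<2x16x4xf32> to tensor<16xf32>
///   %1 = tensor.expand_shape %0 [[0, 1, 2]]
///       : tensor<16xf32> into tensor<1x16x1xf32>
///
/// folds into a single non-rank-reducing extract_slice:
///
///   %1 = tensor.extract_slice %t[0, 0, 0] [1, 16, 1] [1, 1, 1]
///       : tensor<2x16x4xf32> to tensor<1x16x1xf32>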
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp is no longer rank-reducing are supported at the moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
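///
/// Example (illustrative IR with hypothetical shapes; the collapse_shape must
/// drop exactly the unit dims that the rank-reducing insert_slice would have
/// dropped):
///
///   %0 = tensor.collapse_shape %src [[0, 1, 2]]
///       : tensor<1x16x1xf32> into tensor<16xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0] [1, 16, 1] [1, 1, 1]
///       : tensor<16xf32> into tensor<2x16x4xf32>
///
/// folds into a single non-rank-reducing insert_slice of the original source:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0, 0] [1, 16, 1] [1, 1, 1]
///       : tensor<1x16x1xf32> into tensor<2x16x4xf32>
///
/// The same folding applies to tensor.parallel_insert_slice via the OpTy
/// template parameter.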
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp is no longer rank-reducing are supported at the moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};
} // namespace

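/// Note: a typical way to exercise these patterns is through the greedy
/// pattern rewrite driver. Illustrative sketch (assumes a pass context; `op`
/// is a placeholder for the caller's root operation, and the driver requires
/// including "mlir/Transforms/GreedyPatternRewriteDriver.h"):
///
///   RewritePatternSet patterns(op->getContext());
///   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
///   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
///     signalPassFailure();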
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldExpandOfRankReducingExtract,
               FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
      patterns.getContext());
}