1 | //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
10 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
11 | #include "mlir/IR/PatternMatch.h" |
12 | #include "llvm/Support/Debug.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::tensor; |
16 | |
17 | namespace { |
18 | |
19 | template <typename ReshapeOp> |
20 | struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> { |
21 | FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1, |
22 | bool foldSingleUseOnly = false) |
23 | : OpRewritePattern<ReshapeOp>(ctx, benefit), |
24 | foldSingleUseOnly(foldSingleUseOnly) {} |
25 | |
26 | LogicalResult matchAndRewrite(ReshapeOp reshapeOp, |
27 | PatternRewriter &rewriter) const override { |
28 | // Check for tensor.empty source. |
29 | auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>(); |
30 | if (!emptyOp) |
31 | return failure(); |
32 | |
33 | // Check for single use. |
34 | if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) |
35 | return failure(); |
36 | |
37 | // Reify result shape. |
38 | Location loc = reshapeOp.getLoc(); |
39 | ReifiedRankedShapedTypeDims resultShapes; |
40 | if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || |
41 | !llvm::hasSingleElement(C&: resultShapes)) |
42 | return failure(); |
43 | |
44 | // Create new tensor.empty op. |
45 | // TODO: Do not drop tensor type encoding. |
46 | Value emptyTensor = rewriter.create<EmptyOp>( |
47 | loc, resultShapes[0], reshapeOp.getResultType().getElementType()); |
48 | if (emptyTensor.getType() != reshapeOp.getResultType()) { |
49 | rewriter.replaceOpWithNewOp<tensor::CastOp>( |
50 | reshapeOp, reshapeOp.getResultType(), emptyTensor); |
51 | } else { |
52 | rewriter.replaceOp(reshapeOp, emptyTensor); |
53 | } |
54 | return success(); |
55 | } |
56 | |
57 | private: |
58 | bool foldSingleUseOnly = false; |
59 | }; |
60 | |
61 | /// tensor.empty does not define any tensor contents, so a slice of a |
62 | /// tensor.empty can be folded to a smaller tensor.empty. |
63 | struct |
64 | : public OpRewritePattern<ExtractSliceOp> { |
65 | (MLIRContext *ctx, |
66 | PatternBenefit benefit = 1, |
67 | bool foldSingleUseOnly = false) |
68 | : OpRewritePattern<ExtractSliceOp>(ctx, benefit), |
69 | foldSingleUseOnly(foldSingleUseOnly) {} |
70 | |
71 | LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, |
72 | PatternRewriter &rewriter) const override { |
73 | // Check for tensor.empty source. |
74 | auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>(); |
75 | if (!emptyOp) |
76 | return failure(); |
77 | |
78 | // Check for single use. |
79 | if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) |
80 | return failure(); |
81 | |
82 | // Create new tensor.empty op. tensor.extract_slice may be rank-reducing; |
83 | // its dynamic sizes must be preserved as well as its result type. |
84 | auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), |
85 | sliceOp.getType().getElementType(), |
86 | sliceOp.getType().getEncoding()); |
87 | rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType, |
88 | sliceOp.getSizes()); |
89 | return success(); |
90 | } |
91 | |
92 | private: |
93 | bool = false; |
94 | }; |
95 | |
96 | } // namespace |
97 | |
98 | void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, |
99 | bool foldSingleUseOnly) { |
100 | patterns.add<FoldEmptyTensorWithExtractSliceOp, |
101 | FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>, |
102 | FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>( |
103 | patterns.getContext(), /*benefit=*/1, foldSingleUseOnly); |
104 | } |
105 | |