//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
9#include "mlir/Dialect/Tensor/IR/Tensor.h"
10#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11#include "mlir/IR/PatternMatch.h"
12#include "llvm/Support/Debug.h"
13
14using namespace mlir;
15using namespace mlir::tensor;
16
17namespace {
18
19template <typename ReshapeOp>
20struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22 bool foldSingleUseOnly = false)
23 : OpRewritePattern<ReshapeOp>(ctx, benefit),
24 foldSingleUseOnly(foldSingleUseOnly) {}
25
26 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27 PatternRewriter &rewriter) const override {
28 // Check for tensor.empty source.
29 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30 if (!emptyOp)
31 return failure();
32
33 // Check for single use.
34 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35 return failure();
36
37 // Reify result shape.
38 Location loc = reshapeOp.getLoc();
39 ReifiedRankedShapedTypeDims resultShapes;
40 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41 !llvm::hasSingleElement(C&: resultShapes))
42 return failure();
43
44 // Create new tensor.empty op.
45 // TODO: Do not drop tensor type encoding.
46 Value emptyTensor = rewriter.create<EmptyOp>(
47 loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48 if (emptyTensor.getType() != reshapeOp.getResultType()) {
49 rewriter.replaceOpWithNewOp<tensor::CastOp>(
50 reshapeOp, reshapeOp.getResultType(), emptyTensor);
51 } else {
52 rewriter.replaceOp(reshapeOp, emptyTensor);
53 }
54 return success();
55 }
56
57private:
58 bool foldSingleUseOnly = false;
59};
60
61/// tensor.empty does not define any tensor contents, so a slice of a
62/// tensor.empty can be folded to a smaller tensor.empty.
63struct FoldEmptyTensorWithExtractSliceOp
64 : public OpRewritePattern<ExtractSliceOp> {
65 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66 PatternBenefit benefit = 1,
67 bool foldSingleUseOnly = false)
68 : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69 foldSingleUseOnly(foldSingleUseOnly) {}
70
71 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72 PatternRewriter &rewriter) const override {
73 // Check for tensor.empty source.
74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75 if (!emptyOp)
76 return failure();
77
78 // Check for single use.
79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80 return failure();
81
82 // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83 // its dynamic sizes must be preserved as well as its result type.
84 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85 sliceOp.getType().getElementType(),
86 sliceOp.getType().getEncoding());
87 rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88 sliceOp.getSizes());
89 return success();
90 }
91
92private:
93 bool foldSingleUseOnly = false;
94};
95
96} // namespace
97
98void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
99 bool foldSingleUseOnly) {
100 patterns.add<FoldEmptyTensorWithExtractSliceOp,
101 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
102 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
103 patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
104}
105

source code of mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp