1//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9#include "mlir/Dialect/Tensor/IR/Tensor.h"
10#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11#include "mlir/IR/PatternMatch.h"
12
13using namespace mlir;
14using namespace mlir::tensor;
15
16namespace {
17
18template <typename ReshapeOp>
19struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
20 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
21 bool foldSingleUseOnly = false)
22 : OpRewritePattern<ReshapeOp>(ctx, benefit),
23 foldSingleUseOnly(foldSingleUseOnly) {}
24
25 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
26 PatternRewriter &rewriter) const override {
27 // Check for tensor.empty source.
28 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
29 if (!emptyOp)
30 return failure();
31
32 // Check for single use.
33 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
34 return failure();
35
36 // Reify result shape.
37 Location loc = reshapeOp.getLoc();
38 ReifiedRankedShapedTypeDims resultShapes;
39 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
40 !llvm::hasSingleElement(C&: resultShapes))
41 return failure();
42
43 // Create new tensor.empty op.
44 // TODO: Do not drop tensor type encoding.
45 Value emptyTensor = rewriter.create<EmptyOp>(
46 loc, resultShapes[0], reshapeOp.getResultType().getElementType());
47 if (emptyTensor.getType() != reshapeOp.getResultType()) {
48 rewriter.replaceOpWithNewOp<tensor::CastOp>(
49 reshapeOp, reshapeOp.getResultType(), emptyTensor);
50 } else {
51 rewriter.replaceOp(reshapeOp, emptyTensor);
52 }
53 return success();
54 }
55
56private:
57 bool foldSingleUseOnly = false;
58};
59
60/// tensor.empty does not define any tensor contents, so a slice of a
61/// tensor.empty can be folded to a smaller tensor.empty.
62struct FoldEmptyTensorWithExtractSliceOp
63 : public OpRewritePattern<ExtractSliceOp> {
64 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
65 PatternBenefit benefit = 1,
66 bool foldSingleUseOnly = false)
67 : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
68 foldSingleUseOnly(foldSingleUseOnly) {}
69
70 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
71 PatternRewriter &rewriter) const override {
72 // Check for tensor.empty source.
73 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
74 if (!emptyOp)
75 return failure();
76
77 // Check for single use.
78 if (foldSingleUseOnly && !llvm::hasSingleElement(C: emptyOp->getUses()))
79 return failure();
80
81 // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
82 // its dynamic sizes must be preserved as well as its result type.
83 auto tensorType = RankedTensorType::get(shape: sliceOp.getType().getShape(),
84 elementType: sliceOp.getType().getElementType(),
85 encoding: sliceOp.getType().getEncoding());
86 rewriter.replaceOpWithNewOp<EmptyOp>(op: sliceOp, args&: tensorType,
87 args: sliceOp.getSizes());
88 return success();
89 }
90
91private:
92 bool foldSingleUseOnly = false;
93};
94
95// Fold concat operation where all the operands are empty.
96struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
97 using OpRewritePattern<ConcatOp>::OpRewritePattern;
98
99 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
100 PatternRewriter &rewriter) const override {
101 auto concatOperands = concatOp.getInputs();
102 if (concatOperands.empty()) {
103 return failure();
104 }
105 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
106 if (!firstEmptyOp) {
107 return failure();
108 }
109 auto isDefinedByEmptyOp = [](Value v) -> bool {
110 return v.getDefiningOp<tensor::EmptyOp>();
111 };
112 if (!llvm::all_of(Range: concatOperands.drop_front(), P: isDefinedByEmptyOp)) {
113 return rewriter.notifyMatchFailure(
114 arg&: concatOp, msg: "not all operands are defined by an empty op");
115 }
116 SmallVector<SmallVector<OpFoldResult>> resultShape;
117 if (failed(Result: concatOp.reifyResultShapes(builder&: rewriter, reifiedReturnShapes&: resultShape))) {
118 return rewriter.notifyMatchFailure(arg&: concatOp,
119 msg: "failed to get result shape");
120 }
121 rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
122 op: concatOp, args&: resultShape[0], args: concatOp.getResultType().getElementType());
123 return success();
124 }
125};
126
127} // namespace
128
129void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
130 bool foldSingleUseOnly) {
131 patterns.add<FoldEmptyTensorWithExtractSliceOp,
132 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
133 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
134 arg: patterns.getContext(), /*benefit=*/args: 1, args&: foldSingleUseOnly);
135 patterns.add<FoldConcatsOfEmpty>(arg: patterns.getContext(),
136 /*benefit=*/args: 1);
137}
138

// Source: mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp