1//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9#include "mlir/Dialect/Tensor/IR/Tensor.h"
10#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11#include "mlir/IR/PatternMatch.h"
12#include "llvm/Support/Debug.h"
13
14using namespace mlir;
15using namespace mlir::tensor;
16
17namespace {
18
19template <typename ReshapeOp>
20struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22 bool foldSingleUseOnly = false)
23 : OpRewritePattern<ReshapeOp>(ctx, benefit),
24 foldSingleUseOnly(foldSingleUseOnly) {}
25
26 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27 PatternRewriter &rewriter) const override {
28 // Check for tensor.empty source.
29 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30 if (!emptyOp)
31 return failure();
32
33 // Check for single use.
34 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35 return failure();
36
37 // Reify result shape.
38 Location loc = reshapeOp.getLoc();
39 ReifiedRankedShapedTypeDims resultShapes;
40 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41 !llvm::hasSingleElement(C&: resultShapes))
42 return failure();
43
44 // Create new tensor.empty op.
45 // TODO: Do not drop tensor type encoding.
46 Value emptyTensor = rewriter.create<EmptyOp>(
47 loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48 if (emptyTensor.getType() != reshapeOp.getResultType()) {
49 rewriter.replaceOpWithNewOp<tensor::CastOp>(
50 reshapeOp, reshapeOp.getResultType(), emptyTensor);
51 } else {
52 rewriter.replaceOp(reshapeOp, emptyTensor);
53 }
54 return success();
55 }
56
57private:
58 bool foldSingleUseOnly = false;
59};
60
61/// tensor.empty does not define any tensor contents, so a slice of a
62/// tensor.empty can be folded to a smaller tensor.empty.
63struct FoldEmptyTensorWithExtractSliceOp
64 : public OpRewritePattern<ExtractSliceOp> {
65 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66 PatternBenefit benefit = 1,
67 bool foldSingleUseOnly = false)
68 : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69 foldSingleUseOnly(foldSingleUseOnly) {}
70
71 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72 PatternRewriter &rewriter) const override {
73 // Check for tensor.empty source.
74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75 if (!emptyOp)
76 return failure();
77
78 // Check for single use.
79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80 return failure();
81
82 // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83 // its dynamic sizes must be preserved as well as its result type.
84 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85 sliceOp.getType().getElementType(),
86 sliceOp.getType().getEncoding());
87 rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88 sliceOp.getSizes());
89 return success();
90 }
91
92private:
93 bool foldSingleUseOnly = false;
94};
95
96// Fold concat operation where all the operands are empty.
97struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
98 using OpRewritePattern<ConcatOp>::OpRewritePattern;
99
100 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
101 PatternRewriter &rewriter) const override {
102 auto concatOperands = concatOp.getInputs();
103 if (concatOperands.empty()) {
104 return failure();
105 }
106 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
107 if (!firstEmptyOp) {
108 return failure();
109 }
110 auto isDefinedByEmptyOp = [](Value v) -> bool {
111 return v.getDefiningOp<tensor::EmptyOp>();
112 };
113 if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
114 return rewriter.notifyMatchFailure(
115 concatOp, "not all operands are defined by an empty op");
116 }
117 SmallVector<SmallVector<OpFoldResult>> resultShape;
118 if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
119 return rewriter.notifyMatchFailure(concatOp,
120 "failed to get result shape");
121 }
122 rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
123 concatOp, resultShape[0], concatOp.getResultType().getElementType());
124 return success();
125 }
126};
127
128} // namespace
129
130void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
131 bool foldSingleUseOnly) {
132 patterns.add<FoldEmptyTensorWithExtractSliceOp,
133 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
134 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
135 patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
136 patterns.add<FoldConcatsOfEmpty>(arg: patterns.getContext(),
137 /*benefit=*/args: 1);
138}
139

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp