1 | //===- InitTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor ---===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
13 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
16 | |
17 | namespace mlir { |
18 | namespace bufferization { |
19 | #define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSORPASS |
20 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
21 | } // namespace bufferization |
22 | } // namespace mlir |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::bufferization; |
26 | using namespace mlir::tensor; |
27 | |
28 | namespace { |
29 | struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> { |
30 | using OpRewritePattern<tensor::EmptyOp>::OpRewritePattern; |
31 | |
32 | LogicalResult matchAndRewrite(tensor::EmptyOp op, |
33 | PatternRewriter &rewriter) const override { |
34 | rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( |
35 | op, op.getType(), op.getDynamicSizes()); |
36 | return success(); |
37 | } |
38 | }; |
39 | |
40 | struct EmptyTensorToAllocTensor |
41 | : public bufferization::impl::EmptyTensorToAllocTensorPassBase< |
42 | EmptyTensorToAllocTensor> { |
43 | void runOnOperation() override; |
44 | |
45 | void getDependentDialects(DialectRegistry ®istry) const override { |
46 | registry |
47 | .insert<tensor::TensorDialect, bufferization::BufferizationDialect>(); |
48 | } |
49 | }; |
50 | } // namespace |
51 | |
52 | void bufferization::populateEmptyTensorToAllocTensorPattern( |
53 | RewritePatternSet &patterns) { |
54 | patterns.insert<EmptyTensorLoweringPattern>(arg: patterns.getContext()); |
55 | } |
56 | |
57 | void EmptyTensorToAllocTensor::runOnOperation() { |
58 | Operation *op = getOperation(); |
59 | RewritePatternSet patterns(op->getContext()); |
60 | populateEmptyTensorToAllocTensorPattern(patterns); |
61 | if (failed(applyPatternsGreedily(op, std::move(patterns)))) |
62 | signalPassFailure(); |
63 | } |
64 |