//===- EmptyTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
13#include "mlir/Dialect/Tensor/IR/Tensor.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17namespace mlir {
18namespace bufferization {
19#define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSORPASS
20#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21} // namespace bufferization
22} // namespace mlir
23
24using namespace mlir;
25using namespace mlir::bufferization;
26using namespace mlir::tensor;
27
28namespace {
29struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {
30 using OpRewritePattern<tensor::EmptyOp>::OpRewritePattern;
31
32 LogicalResult matchAndRewrite(tensor::EmptyOp op,
33 PatternRewriter &rewriter) const override {
34 rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
35 op, op.getType(), op.getDynamicSizes());
36 return success();
37 }
38};
39
40struct EmptyTensorToAllocTensor
41 : public bufferization::impl::EmptyTensorToAllocTensorPassBase<
42 EmptyTensorToAllocTensor> {
43 void runOnOperation() override;
44
45 void getDependentDialects(DialectRegistry &registry) const override {
46 registry
47 .insert<tensor::TensorDialect, bufferization::BufferizationDialect>();
48 }
49};
50} // namespace
51
52void bufferization::populateEmptyTensorToAllocTensorPattern(
53 RewritePatternSet &patterns) {
54 patterns.insert<EmptyTensorLoweringPattern>(arg: patterns.getContext());
55}
56
57void EmptyTensorToAllocTensor::runOnOperation() {
58 Operation *op = getOperation();
59 RewritePatternSet patterns(op->getContext());
60 populateEmptyTensorToAllocTensorPattern(patterns);
61 if (failed(applyPatternsGreedily(op, std::move(patterns))))
62 signalPassFailure();
63}
64

source code of mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp