//===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that converts Tensor dialect operations
// (currently `tensor.pad`) into the Linalg dialect.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
14
15#include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19
// Expand the TableGen-generated pass definition. Defining
// GEN_PASS_DEF_CONVERTTENSORTOLINALG before including Passes.h.inc selects the
// generated code for this pass, which provides the CRTP base class
// `impl::ConvertTensorToLinalgBase` that ConvertTensorToLinalgPass derives from.
// The include must sit inside `namespace mlir` so the generated code lands in
// that namespace.
namespace mlir {
#define GEN_PASS_DEF_CONVERTTENSORTOLINALG
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

// Lets the rest of this file (and the out-of-line definition of
// mlir::createConvertTensorToLinalgPass) refer to MLIR names unqualified.
using namespace mlir;
26
27namespace {
28/// A pass converting MLIR Tensor operations into the Linalg dialect.
29class ConvertTensorToLinalgPass
30 : public impl::ConvertTensorToLinalgBase<ConvertTensorToLinalgPass> {
31 void runOnOperation() override {
32 auto &context = getContext();
33 ConversionTarget target(context);
34 target
35 .addLegalDialect<mlir::arith::ArithDialect, mlir::linalg::LinalgDialect,
36 mlir::tensor::TensorDialect>();
37 target.addIllegalOp<mlir::tensor::PadOp>();
38
39 RewritePatternSet patterns(&context);
40 populateTensorToLinalgPatterns(patterns);
41
42 if (failed(applyPartialConversion(getOperation(), target,
43 std::move(patterns))))
44 return signalPassFailure();
45 }
46};
47} // namespace
48
49std::unique_ptr<OperationPass<ModuleOp>>
50mlir::createConvertTensorToLinalgPass() {
51 return std::make_unique<ConvertTensorToLinalgPass>();
52}
53

source code of mlir/lib/Conversion/TensorToLinalg/TensorToLinalgPass.cpp