1 | //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
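// This transformation pass legalizes a subset of Tosa operations (convolution,
// pooling, matmul, fully-connected and transpose ops) to Linalg dialect named
// operations; all other Tosa operations are left untouched.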
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
18 | #include "mlir/Dialect/Math/IR/Math.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
21 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
22 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
23 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
24 | #include "mlir/IR/PatternMatch.h" |
25 | #include "mlir/Pass/PassManager.h" |
26 | #include "mlir/Transforms/DialectConversion.h" |
27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
28 | |
29 | namespace mlir { |
30 | #define GEN_PASS_DEF_TOSATOLINALGNAMED |
31 | #include "mlir/Conversion/Passes.h.inc" |
32 | } // namespace mlir |
33 | |
34 | using namespace mlir; |
35 | |
36 | namespace { |
37 | struct TosaToLinalgNamed |
38 | : public impl::TosaToLinalgNamedBase<TosaToLinalgNamed> { |
39 | public: |
40 | TosaToLinalgNamed(const TosaToLinalgNamedOptions &options) |
41 | : impl::TosaToLinalgNamedBase<TosaToLinalgNamed>(options) {} |
42 | |
43 | void getDependentDialects(DialectRegistry ®istry) const override { |
44 | registry |
45 | .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect, |
46 | tensor::TensorDialect, scf::SCFDialect>(); |
47 | } |
48 | |
49 | void runOnOperation() override { |
50 | RewritePatternSet patterns(&getContext()); |
51 | ConversionTarget target(getContext()); |
52 | target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect, |
53 | tensor::TensorDialect, scf::SCFDialect>(); |
54 | |
55 | // Not every TOSA op can be legalized to linalg. |
56 | target.addIllegalOp<tosa::Conv2DOp>(); |
57 | target.addIllegalOp<tosa::Conv3DOp>(); |
58 | target.addIllegalOp<tosa::DepthwiseConv2DOp>(); |
59 | target.addIllegalOp<tosa::MaxPool2dOp>(); |
60 | target.addIllegalOp<tosa::AvgPool2dOp>(); |
61 | target.addIllegalOp<tosa::MatMulOp>(); |
62 | target.addIllegalOp<tosa::FullyConnectedOp>(); |
63 | target.addIllegalOp<tosa::TransposeOp>(); |
64 | |
65 | target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
66 | |
67 | FunctionOpInterface func = getOperation(); |
68 | TosaToLinalgNamedOptions options; |
69 | options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF; |
70 | tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options); |
71 | if (failed(applyFullConversion(func, target, std::move(patterns)))) |
72 | signalPassFailure(); |
73 | } |
74 | }; |
75 | } // namespace |
76 | |
77 | std::unique_ptr<Pass> |
78 | mlir::tosa::createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options) { |
79 | return std::make_unique<TosaToLinalgNamed>(options); |
80 | } |
81 | |