1//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass legalizes Tosa operations to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/Index/IR/IndexDialect.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Tosa/IR/TosaOps.h"
23#include "mlir/Dialect/Tosa/Transforms/Passes.h"
24#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Pass/PassManager.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "mlir/Transforms/Passes.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_TOSATOLINALG
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38namespace {
39struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
40public:
41 void getDependentDialects(DialectRegistry &registry) const override {
42 registry
43 .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
44 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
45 }
46
47 void runOnOperation() override {
48 RewritePatternSet patterns(&getContext());
49 ConversionTarget target(getContext());
50 target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
51 scf::SCFDialect>();
52 target.addIllegalDialect<tosa::TosaDialect>();
53
54 // Not every TOSA op can be legalized to linalg.
55 target.addLegalOp<tosa::ApplyScaleOp>();
56 target.addLegalOp<tosa::IfOp>();
57 target.addLegalOp<tosa::ConstOp>();
58 target.addLegalOp<tosa::WhileOp>();
59 target.addLegalOp<tosa::ConcatOp>();
60 target.addLegalOp<tosa::SliceOp>();
61 target.addLegalOp<tosa::ReshapeOp>();
62 target.addLegalOp<tosa::PadOp>();
63
64 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
65
66 FunctionOpInterface func = getOperation();
67 mlir::tosa::populateTosaToLinalgConversionPatterns(patterns: &patterns);
68 if (failed(applyFullConversion(func, target, std::move(patterns))))
69 signalPassFailure();
70 }
71};
72} // namespace
73
74std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
75 return std::make_unique<TosaToLinalg>();
76}
77
78void mlir::tosa::addTosaToLinalgPasses(
79 OpPassManager &pm, const TosaToLinalgOptions &options,
80 const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
81 tosa::TosaValidationOptions const &validationOptions) {
82 // Optional decompositions are designed to benefit linalg.
83 if (!options.disableTosaDecompositions)
84 pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
85 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
86
87 pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
88 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
89 pm.addNestedPass<func::FuncOp>(
90 tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
91 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
92 // TODO: Remove pass that operates on const tensor and enable optionality
93 pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
94 {options.aggressiveReduceConstant}));
95 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
96 pm.addPass(tosa::createTosaValidation(validationOptions));
97 pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
98}
99
100//===----------------------------------------------------------------------===//
101// Pipeline registration.
102//===----------------------------------------------------------------------===//
103
104void mlir::tosa::registerTosaToLinalgPipelines() {
105 PassPipelineRegistration<>(
106 "tosa-to-linalg-pipeline",
107 "The default pipeline for converting TOSA operators to the equivalent "
108 "operations using the tensor operations in LinAlg as well as LinAlg "
109 "named operations.",
110 [](OpPassManager &pm) {
111 TosaToLinalgOptions tosaToLinalgOptions;
112 TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
113 tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
114 tosaToLinalgNamedOptions,
115 /* validationOptions = */
116 {tosa::TosaProfileEnum::BaseInference,
117 /* StrictOperationSpecAlignment = */ true,
118 tosa::TosaLevelEnum::EightK});
119 });
120}
121

source code of mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp