1//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass legalizes Tosa operations to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/Index/IR/IndexDialect.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Tosa/IR/TosaOps.h"
23#include "mlir/Dialect/Tosa/Transforms/Passes.h"
24#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Pass/PassManager.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "mlir/Transforms/Passes.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_TOSATOLINALG
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38namespace {
39struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
40public:
41 void getDependentDialects(DialectRegistry &registry) const override {
42 registry
43 .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
44 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
45 }
46
47 void runOnOperation() override {
48 RewritePatternSet patterns(&getContext());
49 ConversionTarget target(getContext());
50 target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
51 scf::SCFDialect>();
52 target.addIllegalDialect<tosa::TosaDialect>();
53
54 // Not every TOSA op can be legalized to linalg.
55 target.addLegalOp<tosa::ApplyScaleOp>();
56 target.addLegalOp<tosa::IfOp>();
57 target.addLegalOp<tosa::ConstOp>();
58 target.addLegalOp<tosa::ConstShapeOp>();
59 target.addLegalOp<tosa::WhileOp>();
60 target.addLegalOp<tosa::ConcatOp>();
61 target.addLegalOp<tosa::SliceOp>();
62 target.addLegalOp<tosa::ReshapeOp>();
63 target.addLegalOp<tosa::PadOp>();
64
65 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
66
67 TypeConverter converter;
68 tosa::populateTosaTypeConversion(converter);
69
70 FunctionOpInterface func = getOperation();
71 mlir::tosa::populateTosaToLinalgConversionPatterns(converter, patterns: &patterns);
72 if (failed(applyFullConversion(func, target, std::move(patterns))))
73 signalPassFailure();
74 }
75};
76} // namespace
77
78std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
79 return std::make_unique<TosaToLinalg>();
80}
81
82void mlir::tosa::addTosaToLinalgPasses(
83 OpPassManager &pm, const TosaToLinalgOptions &options,
84 const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
85 std::optional<tosa::TosaValidationOptions> validationOptions) {
86 // Optional decompositions are designed to benefit linalg.
87 if (!options.disableTosaDecompositions)
88 pm.addNestedPass<func::FuncOp>(
89 tosa::createTosaOptionalDecompositionsPass());
90 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
91
92 pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
93 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
94 pm.addNestedPass<func::FuncOp>(
95 tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
96 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
97 // TODO: Remove pass that operates on const tensor and enable optionality
98 pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
99 {options.aggressiveReduceConstant}));
100 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
101 if (validationOptions)
102 pm.addPass(tosa::createTosaValidation(*validationOptions));
103 pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
104}
105
106//===----------------------------------------------------------------------===//
107// Pipeline registration.
108//===----------------------------------------------------------------------===//
109
110void mlir::tosa::registerTosaToLinalgPipelines() {
111 PassPipelineRegistration<>(
112 "tosa-to-linalg-pipeline",
113 "The default pipeline for converting TOSA operators to the equivalent "
114 "operations using the tensor operations in LinAlg as well as LinAlg "
115 "named operations.",
116 [](OpPassManager &pm) {
117 TosaToLinalgOptions tosaToLinalgOptions;
118 TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
119 TosaValidationOptions validationOptions;
120 validationOptions.profile = {"none"};
121 validationOptions.extension = {"none"};
122 validationOptions.strictOpSpecAlignment = false;
123 validationOptions.allowInvalidOpDatatypeCombinations = false;
124 validationOptions.level = tosa::TosaLevelEnum::EightK;
125 tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
126 tosaToLinalgNamedOptions,
127 validationOptions);
128 });
129}
130

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp