1//===- TosaOptionalDecompositions.cpp -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Pass that applies the TOSA operation decompositions
// exposed as populate functions in
// include/mlir/Dialect/Tosa/Transforms/Passes.h
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Tosa/Transforms/Passes.h"
16
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Tosa/IR/TosaOps.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
namespace mlir {
namespace tosa {
// Request the tablegen-generated definition for this pass from Passes.h.inc.
// It supplies `tosa::impl::TosaOptionalDecompositionsBase`, which the pass
// struct below derives from; it must be expanded inside `mlir::tosa` so the
// generated names land in that namespace.
#define GEN_PASS_DEF_TOSAOPTIONALDECOMPOSITIONS
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir
28
29using namespace mlir;
30
31namespace {
32
33struct TosaOptionalDecompositions
34 : public tosa::impl::TosaOptionalDecompositionsBase<
35 TosaOptionalDecompositions> {
36 void runOnOperation() override {
37 auto *ctx = &getContext();
38 RewritePatternSet patterns(ctx);
39 auto func = getOperation();
40
41 mlir::tosa::populateTosaDecomposeConv2D(ctx: ctx, patterns);
42 mlir::tosa::populateTosaDecomposeTransposeConv(ctx: ctx, patterns);
43 mlir::tosa::populateTosaDecomposeDepthwise(ctx: ctx, patterns);
44
45 if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
46 signalPassFailure();
47 }
48};
49
50} // namespace
51
52std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
53 return std::make_unique<TosaOptionalDecompositions>();
54}
55

source code of mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp