1//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass that applies constant-folding transformations,
// together with the canonicalization patterns of all TOSA ops, to TOSA
// operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Tosa/Transforms/Passes.h"
14
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
namespace mlir {
namespace tosa {
// Pull in the generated definition of this pass's base class
// (tosa::impl::TosaLayerwiseConstantFoldPassBase, which the pass below derives
// from). The GEN_PASS_DEF_* macro selects this one pass's definition from the
// generated .inc file; it must be defined immediately before the include and
// inside namespace mlir::tosa so the base lands in the expected namespace.
#define GEN_PASS_DEF_TOSALAYERWISECONSTANTFOLDPASS
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;
29
30namespace {
31
32template <typename... Args>
33void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
34 (Args::getCanonicalizationPatterns(patterns, ctx), ...);
35}
36
37void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
38 RewritePatternSet &patterns) {
39 addOpsCanonicalizations<
40#define GET_OP_LIST
41#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
42 >(ctx, patterns);
43}
44
45struct TosaLayerwiseConstantFoldPass
46 : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
47 TosaLayerwiseConstantFoldPass> {
48 TosaLayerwiseConstantFoldPass(
49 const TosaLayerwiseConstantFoldPassOptions &options)
50 : TosaLayerwiseConstantFoldPassBase(options) {}
51
52 void runOnOperation() override {
53 auto *ctx = &getContext();
54 RewritePatternSet patterns(ctx);
55 auto func = getOperation();
56
57 mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx: ctx, patterns);
58 mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx: ctx, patterns);
59 mlir::tosa::populateTosaConstantReduction(ctx, patterns,
60 aggressiveReduceConstant);
61 populateTosaOpsCanonicalizationPatterns(ctx, patterns);
62
63 if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
64 signalPassFailure();
65 }
66};
67
68} // namespace
69
70std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
71 return std::make_unique<TosaLayerwiseConstantFoldPass>(
72 TosaLayerwiseConstantFoldPassOptions{false});
73}
74
75std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(
76 const TosaLayerwiseConstantFoldPassOptions &options) {
77 return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
78}
79

source code of mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp