1//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the TOSA layerwise constant-folding pass, which greedily
// applies TOSA constant-folding and canonicalization patterns.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Tosa/Transforms/Passes.h"
14
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17
18namespace mlir {
19namespace tosa {
20#define GEN_PASS_DEF_TOSALAYERWISECONSTANTFOLDPASS
21#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
22} // namespace tosa
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::tosa;
27
28namespace {
29
30template <typename... Args>
31void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
32 (Args::getCanonicalizationPatterns(patterns, ctx), ...);
33}
34
35void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
36 RewritePatternSet &patterns) {
37 addOpsCanonicalizations<
38#define GET_OP_LIST
39#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
40 >(ctx, patterns);
41}
42
43struct TosaLayerwiseConstantFoldPass
44 : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
45 TosaLayerwiseConstantFoldPass> {
46 using Base::Base;
47
48 void runOnOperation() override {
49 auto *ctx = &getContext();
50 RewritePatternSet patterns(ctx);
51 auto func = getOperation();
52
53 mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
54 mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
55 mlir::tosa::populateTosaConstantReduction(ctx, patterns,
56 aggressiveReduceConstant);
57 populateTosaOpsCanonicalizationPatterns(ctx, patterns);
58
59 if (applyPatternsGreedily(op: func, patterns: std::move(patterns)).failed())
60 signalPassFailure();
61 }
62};
63
64} // namespace
65

// Source: mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp