//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant folding transformations on TOSA operations.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Tosa/Transforms/Passes.h"
14
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20namespace mlir {
21namespace tosa {
22#define GEN_PASS_DEF_TOSALAYERWISECONSTANTFOLDPASS
23#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
24} // namespace tosa
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::tosa;
29
30namespace {
31
32template <typename... Args>
33void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
34 (Args::getCanonicalizationPatterns(patterns, ctx), ...);
35}
36
37void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
38 RewritePatternSet &patterns) {
39 addOpsCanonicalizations<
40#define GET_OP_LIST
41#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
42 >(ctx, patterns);
43}
44
45struct TosaLayerwiseConstantFoldPass
46 : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
47 TosaLayerwiseConstantFoldPass> {
48 using Base::Base;
49
50 void runOnOperation() override {
51 auto *ctx = &getContext();
52 RewritePatternSet patterns(ctx);
53 auto func = getOperation();
54
55 mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx: ctx, patterns);
56 mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx: ctx, patterns);
57 mlir::tosa::populateTosaConstantReduction(ctx, patterns,
58 aggressiveReduceConstant);
59 populateTosaOpsCanonicalizationPatterns(ctx, patterns);
60
61 if (applyPatternsGreedily(func, std::move(patterns)).failed())
62 signalPassFailure();
63 }
64};
65
66} // namespace
67

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp