//===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/IR/SymbolTable.h"
11
12namespace mlir {
13namespace {
14
15#define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
16#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
17
18// Define a notion of function equivalence that allows for reuse. Ignore the
19// symbol name for this purpose.
20struct DuplicateFuncOpEquivalenceInfo
21 : public llvm::DenseMapInfo<func::FuncOp> {
22
23 static unsigned getHashValue(const func::FuncOp cFunc) {
24 if (!cFunc) {
25 return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
26 }
27
28 // Aggregate attributes, ignoring the symbol name.
29 llvm::hash_code hash = {};
30 func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
31 StringAttr symNameAttrName = func.getSymNameAttrName();
32 for (NamedAttribute namedAttr : cFunc->getAttrs()) {
33 StringAttr attrName = namedAttr.getName();
34 if (attrName == symNameAttrName)
35 continue;
36 hash = llvm::hash_combine(hash, namedAttr);
37 }
38
39 // Also hash the func body.
40 func.getBody().walk([&](Operation *op) {
41 hash = llvm::hash_combine(
42 hash, OperationEquivalence::computeHash(
43 op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
44 /*hashResults=*/OperationEquivalence::ignoreHashValue,
45 OperationEquivalence::IgnoreLocations));
46 });
47
48 return hash;
49 }
50
51 static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
52 if (lhs == rhs)
53 return true;
54 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
55 rhs == getTombstoneKey() || rhs == getEmptyKey())
56 return false;
57 // Check discardable attributes equivalence
58 if (lhs->getDiscardableAttrDictionary() !=
59 rhs->getDiscardableAttrDictionary())
60 return false;
61
62 // Check properties equivalence, ignoring the symbol name.
63 // Make a copy, so that we can erase the symbol name and perform the
64 // comparison.
65 auto pLhs = lhs.getProperties();
66 auto pRhs = rhs.getProperties();
67 pLhs.sym_name = nullptr;
68 pRhs.sym_name = nullptr;
69 if (pLhs != pRhs)
70 return false;
71
72 // Compare inner workings.
73 return OperationEquivalence::isRegionEquivalentTo(
74 &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
75 }
76};
77
78struct DuplicateFunctionEliminationPass
79 : public impl::DuplicateFunctionEliminationPassBase<
80 DuplicateFunctionEliminationPass> {
81
82 using DuplicateFunctionEliminationPassBase<
83 DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
84
85 void runOnOperation() override {
86 auto module = getOperation();
87
88 // Find unique representant per equivalent func ops.
89 DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
90 DenseMap<StringAttr, func::FuncOp> getRepresentant;
91 DenseSet<func::FuncOp> toBeErased;
92 module.walk([&](func::FuncOp f) {
93 auto [repr, inserted] = uniqueFuncOps.insert(f);
94 getRepresentant[f.getSymNameAttr()] = *repr;
95 if (!inserted) {
96 toBeErased.insert(f);
97 }
98 });
99
100 // Update call ops to call unique func op representants.
101 module.walk([&](func::CallOp callOp) {
102 func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
103 callOp.setCallee(callee.getSymName());
104 });
105
106 // Erase redundant func ops.
107 for (auto it : toBeErased) {
108 it.erase();
109 }
110 }
111};
112
113} // namespace
114
115std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
116 return std::make_unique<DuplicateFunctionEliminationPass>();
117}
118
119} // namespace mlir
120

// Source file: mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp