1//===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Func/Transforms/Passes.h"
11
12namespace mlir {
13namespace func {
14#define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
15#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
16} // namespace func
17
18namespace {
19
20// Define a notion of function equivalence that allows for reuse. Ignore the
21// symbol name for this purpose.
22struct DuplicateFuncOpEquivalenceInfo
23 : public llvm::DenseMapInfo<func::FuncOp> {
24
25 static unsigned getHashValue(const func::FuncOp cFunc) {
26 if (!cFunc) {
27 return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
28 }
29
30 // Aggregate attributes, ignoring the symbol name.
31 llvm::hash_code hash = {};
32 func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
33 StringAttr symNameAttrName = func.getSymNameAttrName();
34 for (NamedAttribute namedAttr : cFunc->getAttrs()) {
35 StringAttr attrName = namedAttr.getName();
36 if (attrName == symNameAttrName)
37 continue;
38 hash = llvm::hash_combine(hash, namedAttr);
39 }
40
41 // Also hash the func body.
42 func.getBody().walk([&](Operation *op) {
43 hash = llvm::hash_combine(
44 hash, OperationEquivalence::computeHash(
45 op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
46 /*hashResults=*/OperationEquivalence::ignoreHashValue,
47 OperationEquivalence::IgnoreLocations));
48 });
49
50 return hash;
51 }
52
53 static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
54 if (lhs == rhs)
55 return true;
56 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
57 rhs == getTombstoneKey() || rhs == getEmptyKey())
58 return false;
59
60 if (lhs.isDeclaration() || rhs.isDeclaration())
61 return false;
62
63 // Check discardable attributes equivalence
64 if (lhs->getDiscardableAttrDictionary() !=
65 rhs->getDiscardableAttrDictionary())
66 return false;
67
68 // Check properties equivalence, ignoring the symbol name.
69 // Make a copy, so that we can erase the symbol name and perform the
70 // comparison.
71 auto pLhs = lhs.getProperties();
72 auto pRhs = rhs.getProperties();
73 pLhs.sym_name = nullptr;
74 pRhs.sym_name = nullptr;
75 if (pLhs != pRhs)
76 return false;
77
78 // Compare inner workings.
79 return OperationEquivalence::isRegionEquivalentTo(
80 &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
81 }
82};
83
84struct DuplicateFunctionEliminationPass
85 : public func::impl::DuplicateFunctionEliminationPassBase<
86 DuplicateFunctionEliminationPass> {
87
88 using DuplicateFunctionEliminationPassBase<
89 DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
90
91 void runOnOperation() override {
92 auto module = getOperation();
93
94 // Find unique representant per equivalent func ops.
95 DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
96 DenseMap<StringAttr, func::FuncOp> getRepresentant;
97 DenseSet<func::FuncOp> toBeErased;
98 module.walk([&](func::FuncOp f) {
99 auto [repr, inserted] = uniqueFuncOps.insert(f);
100 getRepresentant[f.getSymNameAttr()] = *repr;
101 if (!inserted) {
102 toBeErased.insert(f);
103 }
104 });
105
106 // Update all symbol uses to reference unique func op
107 // representants and erase redundant func ops.
108 SymbolTableCollection symbolTable;
109 SymbolUserMap userMap(symbolTable, module);
110 for (auto it : toBeErased) {
111 StringAttr oldSymbol = it.getSymNameAttr();
112 StringAttr newSymbol = getRepresentant[oldSymbol].getSymNameAttr();
113 userMap.replaceAllUsesWith(it, newSymbol);
114 it.erase();
115 }
116 }
117};
118
119} // namespace
120} // namespace mlir
121

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp