1 | //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file declares various operation folding utilities. These |
10 | // utilities are intended to be used by passes to unify and simplify their logic. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_TRANSFORMS_FOLDUTILS_H |
15 | #define MLIR_TRANSFORMS_FOLDUTILS_H |
16 | |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/Dialect.h" |
19 | #include "mlir/IR/DialectInterface.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | #include "mlir/Interfaces/FoldInterfaces.h" |
22 | |
23 | namespace mlir { |
24 | class Operation; |
25 | class Value; |
26 | |
27 | //===--------------------------------------------------------------------===// |
28 | // OperationFolder |
29 | //===--------------------------------------------------------------------===// |
30 | |
31 | /// A utility class for folding operations, and unifying duplicated constants |
32 | /// generated along the way. |
33 | class OperationFolder { |
34 | public: |
35 | OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr) |
36 | : erasedFoldedLocation(UnknownLoc::get(ctx)), interfaces(ctx), |
37 | rewriter(ctx, listener) {} |
38 | |
39 | /// Tries to perform folding on the given `op`, including unifying |
40 | /// deduplicated constants. If successful, replaces `op`'s uses with |
41 | /// folded results, and returns success. If the op was completely folded it is |
42 | /// erased. If it is just updated in place, `inPlaceUpdate` is set to true. |
43 | LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr); |
44 | |
45 | /// Tries to fold a pre-existing constant operation. `constValue` represents |
46 | /// the value of the constant, and can be optionally passed if the value is |
47 | /// already known (e.g. if the constant was discovered by m_Constant). This is |
48 | /// purely an optimization opportunity for callers that already know the value |
49 | /// of the constant. Returns false if an existing constant for `op` already |
50 | /// exists in the folder, in which case `op` is replaced and erased. |
51 | /// Otherwise, returns true and `op` is inserted into the folder (and |
52 | /// hoisted if necessary). |
53 | bool insertKnownConstant(Operation *op, Attribute constValue = {}); |
54 | |
55 | /// Notifies that the given constant `op` should be remove from this |
56 | /// OperationFolder's internal bookkeeping. |
57 | /// |
58 | /// Note: this method must be called if a constant op is to be deleted |
59 | /// externally to this OperationFolder. `op` must be a constant op. |
60 | void notifyRemoval(Operation *op); |
61 | |
62 | /// Clear out any constants cached inside of the folder. |
63 | void clear(); |
64 | |
65 | /// Get or create a constant for use in the specified block. The constant may |
66 | /// be created in a parent block. On success this returns the constant |
67 | /// operation, nullptr otherwise. |
68 | Value getOrCreateConstant(Block *block, Dialect *dialect, Attribute value, |
69 | Type type); |
70 | |
71 | private: |
72 | /// This map keeps track of uniqued constants by dialect, attribute, and type. |
73 | /// A constant operation materializes an attribute with a type. Dialects may |
74 | /// generate different constants with the same input attribute and type, so we |
75 | /// also need to track per-dialect. |
76 | using ConstantMap = |
77 | DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>; |
78 | |
79 | /// Returns true if the given operation is an already folded constant that is |
80 | /// owned by this folder. |
81 | bool isFolderOwnedConstant(Operation *op) const; |
82 | |
83 | /// Tries to perform folding on the given `op`. If successful, populates |
84 | /// `results` with the results of the folding. |
85 | LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results); |
86 | |
87 | /// Try to process a set of fold results. Populates `results` on success, |
88 | /// otherwise leaves it unchanged. |
89 | LogicalResult processFoldResults(Operation *op, |
90 | SmallVectorImpl<Value> &results, |
91 | ArrayRef<OpFoldResult> foldResults); |
92 | |
93 | /// Try to get or create a new constant entry. On success this returns the |
94 | /// constant operation, nullptr otherwise. |
95 | Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants, |
96 | Dialect *dialect, Attribute value, |
97 | Type type, Location loc); |
98 | |
99 | /// The location to overwrite with for folder-owned constants. |
100 | UnknownLoc erasedFoldedLocation; |
101 | |
102 | /// A mapping between an insertion region and the constants that have been |
103 | /// created within it. |
104 | DenseMap<Region *, ConstantMap> foldScopes; |
105 | |
106 | /// This map tracks all of the dialects that an operation is referenced by; |
107 | /// given that many dialects may generate the same constant. |
108 | DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects; |
109 | |
110 | /// A collection of dialect folder interfaces. |
111 | DialectInterfaceCollection<DialectFoldInterface> interfaces; |
112 | |
113 | /// A rewriter that performs all IR modifications. |
114 | IRRewriter rewriter; |
115 | }; |
116 | |
117 | } // namespace mlir |
118 | |
119 | #endif // MLIR_TRANSFORMS_FOLDUTILS_H |
120 | |