1 | //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines various operation fold utilities. These utilities are |
// intended to be used by passes to unify and simplify their logic.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Transforms/FoldUtils.h" |
15 | |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/Matchers.h" |
18 | #include "mlir/IR/Operation.h" |
19 | |
20 | using namespace mlir; |
21 | |
22 | /// Given an operation, find the parent region that folded constants should be |
23 | /// inserted into. |
24 | static Region * |
25 | getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, |
26 | Block *insertionBlock) { |
27 | while (Region *region = insertionBlock->getParent()) { |
28 | // Insert in this region for any of the following scenarios: |
29 | // * The parent is unregistered, or is known to be isolated from above. |
30 | // * The parent is a top-level operation. |
31 | auto *parentOp = region->getParentOp(); |
32 | if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || |
33 | !parentOp->getBlock()) |
34 | return region; |
35 | |
36 | // Otherwise, check if this region is a desired insertion region. |
37 | auto *interface = interfaces.getInterfaceFor(obj: parentOp); |
38 | if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) |
39 | return region; |
40 | |
41 | // Traverse up the parent looking for an insertion region. |
42 | insertionBlock = parentOp->getBlock(); |
43 | } |
44 | llvm_unreachable("expected valid insertion region" ); |
45 | } |
46 | |
47 | /// A utility function used to materialize a constant for a given attribute and |
48 | /// type. On success, a valid constant value is returned. Otherwise, null is |
49 | /// returned |
50 | static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, |
51 | Attribute value, Type type, |
52 | Location loc) { |
53 | auto insertPt = builder.getInsertionPoint(); |
54 | (void)insertPt; |
55 | |
56 | // Ask the dialect to materialize a constant operation for this value. |
57 | if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { |
58 | assert(insertPt == builder.getInsertionPoint()); |
59 | assert(matchPattern(constOp, m_Constant())); |
60 | return constOp; |
61 | } |
62 | |
63 | return nullptr; |
64 | } |
65 | |
66 | //===----------------------------------------------------------------------===// |
67 | // OperationFolder |
68 | //===----------------------------------------------------------------------===// |
69 | |
70 | LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) { |
71 | if (inPlaceUpdate) |
72 | *inPlaceUpdate = false; |
73 | |
74 | // If this is a unique'd constant, return failure as we know that it has |
75 | // already been folded. |
76 | if (isFolderOwnedConstant(op)) { |
77 | // Check to see if we should rehoist, i.e. if a non-constant operation was |
78 | // inserted before this one. |
79 | Block *opBlock = op->getBlock(); |
80 | if (&opBlock->front() != op && !isFolderOwnedConstant(op: op->getPrevNode())) { |
81 | op->moveBefore(existingOp: &opBlock->front()); |
82 | op->setLoc(erasedFoldedLocation); |
83 | } |
84 | return failure(); |
85 | } |
86 | |
87 | // Try to fold the operation. |
88 | SmallVector<Value, 8> results; |
89 | if (failed(result: tryToFold(op, results))) |
90 | return failure(); |
91 | |
92 | // Check to see if the operation was just updated in place. |
93 | if (results.empty()) { |
94 | if (inPlaceUpdate) |
95 | *inPlaceUpdate = true; |
96 | if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>( |
97 | Val: rewriter.getListener())) { |
98 | // Folding API does not notify listeners, so we have to notify manually. |
99 | rewriteListener->notifyOperationModified(op); |
100 | } |
101 | return success(); |
102 | } |
103 | |
104 | // Constant folding succeeded. Replace all of the result values and erase the |
105 | // operation. |
106 | notifyRemoval(op); |
107 | rewriter.replaceOp(op, newValues: results); |
108 | return success(); |
109 | } |
110 | |
111 | bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { |
112 | Block *opBlock = op->getBlock(); |
113 | |
114 | // If this is a constant we unique'd, we don't need to insert, but we can |
115 | // check to see if we should rehoist it. |
116 | if (isFolderOwnedConstant(op)) { |
117 | if (&opBlock->front() != op && !isFolderOwnedConstant(op: op->getPrevNode())) { |
118 | op->moveBefore(existingOp: &opBlock->front()); |
119 | op->setLoc(erasedFoldedLocation); |
120 | } |
121 | return true; |
122 | } |
123 | |
124 | // Get the constant value of the op if necessary. |
125 | if (!constValue) { |
126 | matchPattern(op, pattern: m_Constant(bind_value: &constValue)); |
127 | assert(constValue && "expected `op` to be a constant" ); |
128 | } else { |
129 | // Ensure that the provided constant was actually correct. |
130 | #ifndef NDEBUG |
131 | Attribute expectedValue; |
132 | matchPattern(op, pattern: m_Constant(bind_value: &expectedValue)); |
133 | assert( |
134 | expectedValue == constValue && |
135 | "provided constant value was not the expected value of the constant" ); |
136 | #endif |
137 | } |
138 | |
139 | // Check for an existing constant operation for the attribute value. |
140 | Region *insertRegion = getInsertionRegion(interfaces, insertionBlock: opBlock); |
141 | auto &uniquedConstants = foldScopes[insertRegion]; |
142 | Operation *&folderConstOp = uniquedConstants[std::make_tuple( |
143 | args: op->getDialect(), args&: constValue, args: *op->result_type_begin())]; |
144 | |
145 | // If there is an existing constant, replace `op`. |
146 | if (folderConstOp) { |
147 | notifyRemoval(op); |
148 | rewriter.replaceOp(op, newValues: folderConstOp->getResults()); |
149 | folderConstOp->setLoc(erasedFoldedLocation); |
150 | return false; |
151 | } |
152 | |
153 | // Otherwise, we insert `op`. If `op` is in the insertion block and is either |
154 | // already at the front of the block, or the previous operation is already a |
155 | // constant we unique'd (i.e. one we inserted), then we don't need to do |
156 | // anything. Otherwise, we move the constant to the insertion block. |
157 | Block *insertBlock = &insertRegion->front(); |
158 | if (opBlock != insertBlock || (&insertBlock->front() != op && |
159 | !isFolderOwnedConstant(op: op->getPrevNode()))) { |
160 | op->moveBefore(existingOp: &insertBlock->front()); |
161 | op->setLoc(erasedFoldedLocation); |
162 | } |
163 | |
164 | folderConstOp = op; |
165 | referencedDialects[op].push_back(Elt: op->getDialect()); |
166 | return true; |
167 | } |
168 | |
169 | /// Notifies that the given constant `op` should be remove from this |
170 | /// OperationFolder's internal bookkeeping. |
171 | void OperationFolder::notifyRemoval(Operation *op) { |
172 | // Check to see if this operation is uniqued within the folder. |
173 | auto it = referencedDialects.find(Val: op); |
174 | if (it == referencedDialects.end()) |
175 | return; |
176 | |
177 | // Get the constant value for this operation, this is the value that was used |
178 | // to unique the operation internally. |
179 | Attribute constValue; |
180 | matchPattern(op, pattern: m_Constant(bind_value: &constValue)); |
181 | assert(constValue); |
182 | |
183 | // Get the constant map that this operation was uniqued in. |
184 | auto &uniquedConstants = |
185 | foldScopes[getInsertionRegion(interfaces, insertionBlock: op->getBlock())]; |
186 | |
187 | // Erase all of the references to this operation. |
188 | auto type = op->getResult(idx: 0).getType(); |
189 | for (auto *dialect : it->second) |
190 | uniquedConstants.erase(Val: std::make_tuple(args&: dialect, args&: constValue, args&: type)); |
191 | referencedDialects.erase(I: it); |
192 | } |
193 | |
194 | /// Clear out any constants cached inside of the folder. |
195 | void OperationFolder::clear() { |
196 | foldScopes.clear(); |
197 | referencedDialects.clear(); |
198 | } |
199 | |
200 | /// Get or create a constant using the given builder. On success this returns |
201 | /// the constant operation, nullptr otherwise. |
202 | Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect, |
203 | Attribute value, Type type) { |
204 | // Find an insertion point for the constant. |
205 | auto *insertRegion = getInsertionRegion(interfaces, insertionBlock: block); |
206 | auto &entry = insertRegion->front(); |
207 | rewriter.setInsertionPoint(block: &entry, insertPoint: entry.begin()); |
208 | |
209 | // Get the constant map for the insertion region of this operation. |
210 | // Use erased location since the op is being built at the front of block. |
211 | auto &uniquedConstants = foldScopes[insertRegion]; |
212 | Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value, |
213 | type, erasedFoldedLocation); |
214 | return constOp ? constOp->getResult(idx: 0) : Value(); |
215 | } |
216 | |
217 | bool OperationFolder::isFolderOwnedConstant(Operation *op) const { |
218 | return referencedDialects.count(Val: op); |
219 | } |
220 | |
221 | /// Tries to perform folding on the given `op`. If successful, populates |
222 | /// `results` with the results of the folding. |
223 | LogicalResult OperationFolder::tryToFold(Operation *op, |
224 | SmallVectorImpl<Value> &results) { |
225 | SmallVector<OpFoldResult, 8> foldResults; |
226 | if (failed(result: op->fold(results&: foldResults)) || |
227 | failed(result: processFoldResults(op, results, foldResults))) |
228 | return failure(); |
229 | return success(); |
230 | } |
231 | |
232 | LogicalResult |
233 | OperationFolder::processFoldResults(Operation *op, |
234 | SmallVectorImpl<Value> &results, |
235 | ArrayRef<OpFoldResult> foldResults) { |
236 | // Check to see if the operation was just updated in place. |
237 | if (foldResults.empty()) |
238 | return success(); |
239 | assert(foldResults.size() == op->getNumResults()); |
240 | |
241 | // Create a builder to insert new operations into the entry block of the |
242 | // insertion region. |
243 | auto *insertRegion = getInsertionRegion(interfaces, insertionBlock: op->getBlock()); |
244 | auto &entry = insertRegion->front(); |
245 | rewriter.setInsertionPoint(block: &entry, insertPoint: entry.begin()); |
246 | |
247 | // Get the constant map for the insertion region of this operation. |
248 | auto &uniquedConstants = foldScopes[insertRegion]; |
249 | |
250 | // Create the result constants and replace the results. |
251 | auto *dialect = op->getDialect(); |
252 | for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { |
253 | assert(!foldResults[i].isNull() && "expected valid OpFoldResult" ); |
254 | |
255 | // Check if the result was an SSA value. |
256 | if (auto repl = llvm::dyn_cast_if_present<Value>(Val: foldResults[i])) { |
257 | results.emplace_back(Args&: repl); |
258 | continue; |
259 | } |
260 | |
261 | // Check to see if there is a canonicalized version of this constant. |
262 | auto res = op->getResult(idx: i); |
263 | Attribute attrRepl = foldResults[i].get<Attribute>(); |
264 | if (auto *constOp = |
265 | tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl, |
266 | res.getType(), erasedFoldedLocation)) { |
267 | // Ensure that this constant dominates the operation we are replacing it |
268 | // with. This may not automatically happen if the operation being folded |
269 | // was inserted before the constant within the insertion block. |
270 | Block *opBlock = op->getBlock(); |
271 | if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) |
272 | constOp->moveBefore(&opBlock->front()); |
273 | |
274 | results.push_back(Elt: constOp->getResult(0)); |
275 | continue; |
276 | } |
277 | // If materialization fails, cleanup any operations generated for the |
278 | // previous results and return failure. |
279 | for (Operation &op : llvm::make_early_inc_range( |
280 | Range: llvm::make_range(x: entry.begin(), y: rewriter.getInsertionPoint()))) { |
281 | notifyRemoval(op: &op); |
282 | rewriter.eraseOp(op: &op); |
283 | } |
284 | |
285 | results.clear(); |
286 | return failure(); |
287 | } |
288 | |
289 | return success(); |
290 | } |
291 | |
292 | /// Try to get or create a new constant entry. On success this returns the |
293 | /// constant operation value, nullptr otherwise. |
294 | Operation * |
295 | OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants, |
296 | Dialect *dialect, Attribute value, |
297 | Type type, Location loc) { |
298 | // Check if an existing mapping already exists. |
299 | auto constKey = std::make_tuple(args&: dialect, args&: value, args&: type); |
300 | Operation *&constOp = uniquedConstants[constKey]; |
301 | if (constOp) { |
302 | if (loc != constOp->getLoc()) |
303 | constOp->setLoc(erasedFoldedLocation); |
304 | return constOp; |
305 | } |
306 | |
307 | // If one doesn't exist, try to materialize one. |
308 | if (!(constOp = materializeConstant(dialect, builder&: rewriter, value, type, loc))) |
309 | return nullptr; |
310 | |
311 | // Check to see if the generated constant is in the expected dialect. |
312 | auto *newDialect = constOp->getDialect(); |
313 | if (newDialect == dialect) { |
314 | referencedDialects[constOp].push_back(Elt: dialect); |
315 | return constOp; |
316 | } |
317 | |
318 | // If it isn't, then we also need to make sure that the mapping for the new |
319 | // dialect is valid. |
320 | auto newKey = std::make_tuple(args&: newDialect, args&: value, args&: type); |
321 | |
322 | // If an existing operation in the new dialect already exists, delete the |
323 | // materialized operation in favor of the existing one. |
324 | if (auto *existingOp = uniquedConstants.lookup(Val: newKey)) { |
325 | notifyRemoval(op: constOp); |
326 | rewriter.eraseOp(op: constOp); |
327 | referencedDialects[existingOp].push_back(Elt: dialect); |
328 | if (loc != existingOp->getLoc()) |
329 | existingOp->setLoc(erasedFoldedLocation); |
330 | return constOp = existingOp; |
331 | } |
332 | |
333 | // Otherwise, update the new dialect to the materialized operation. |
334 | referencedDialects[constOp].assign(IL: {dialect, newDialect}); |
335 | auto newIt = uniquedConstants.insert(KV: {newKey, constOp}); |
336 | return newIt.first->second; |
337 | } |
338 | |