1//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/FoldUtils.h"
15
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Operation.h"
19
20using namespace mlir;
21
22/// Given an operation, find the parent region that folded constants should be
23/// inserted into.
24static Region *
25getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
26 Block *insertionBlock) {
27 while (Region *region = insertionBlock->getParent()) {
28 // Insert in this region for any of the following scenarios:
29 // * The parent is unregistered, or is known to be isolated from above.
30 // * The parent is a top-level operation.
31 auto *parentOp = region->getParentOp();
32 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33 !parentOp->getBlock())
34 return region;
35
36 // Otherwise, check if this region is a desired insertion region.
37 auto *interface = interfaces.getInterfaceFor(obj: parentOp);
38 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39 return region;
40
41 // Traverse up the parent looking for an insertion region.
42 insertionBlock = parentOp->getBlock();
43 }
44 llvm_unreachable("expected valid insertion region");
45}
46
47/// A utility function used to materialize a constant for a given attribute and
48/// type. On success, a valid constant value is returned. Otherwise, null is
49/// returned
50static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51 Attribute value, Type type,
52 Location loc) {
53 auto insertPt = builder.getInsertionPoint();
54 (void)insertPt;
55
56 // Ask the dialect to materialize a constant operation for this value.
57 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58 assert(insertPt == builder.getInsertionPoint());
59 assert(matchPattern(constOp, m_Constant()));
60 return constOp;
61 }
62
63 return nullptr;
64}
65
66//===----------------------------------------------------------------------===//
67// OperationFolder
68//===----------------------------------------------------------------------===//
69
70LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
71 if (inPlaceUpdate)
72 *inPlaceUpdate = false;
73
74 // If this is a unique'd constant, return failure as we know that it has
75 // already been folded.
76 if (isFolderOwnedConstant(op)) {
77 // Check to see if we should rehoist, i.e. if a non-constant operation was
78 // inserted before this one.
79 Block *opBlock = op->getBlock();
80 if (&opBlock->front() != op && !isFolderOwnedConstant(op: op->getPrevNode())) {
81 op->moveBefore(existingOp: &opBlock->front());
82 op->setLoc(erasedFoldedLocation);
83 }
84 return failure();
85 }
86
87 // Try to fold the operation.
88 SmallVector<Value, 8> results;
89 if (failed(result: tryToFold(op, results)))
90 return failure();
91
92 // Check to see if the operation was just updated in place.
93 if (results.empty()) {
94 if (inPlaceUpdate)
95 *inPlaceUpdate = true;
96 if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
97 Val: rewriter.getListener())) {
98 // Folding API does not notify listeners, so we have to notify manually.
99 rewriteListener->notifyOperationModified(op);
100 }
101 return success();
102 }
103
104 // Constant folding succeeded. Replace all of the result values and erase the
105 // operation.
106 notifyRemoval(op);
107 rewriter.replaceOp(op, newValues: results);
108 return success();
109}
110
111bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
112 Block *opBlock = op->getBlock();
113
114 // If this is a constant we unique'd, we don't need to insert, but we can
115 // check to see if we should rehoist it.
116 if (isFolderOwnedConstant(op)) {
117 if (&opBlock->front() != op && !isFolderOwnedConstant(op: op->getPrevNode())) {
118 op->moveBefore(existingOp: &opBlock->front());
119 op->setLoc(erasedFoldedLocation);
120 }
121 return true;
122 }
123
124 // Get the constant value of the op if necessary.
125 if (!constValue) {
126 matchPattern(op, pattern: m_Constant(bind_value: &constValue));
127 assert(constValue && "expected `op` to be a constant");
128 } else {
129 // Ensure that the provided constant was actually correct.
130#ifndef NDEBUG
131 Attribute expectedValue;
132 matchPattern(op, pattern: m_Constant(bind_value: &expectedValue));
133 assert(
134 expectedValue == constValue &&
135 "provided constant value was not the expected value of the constant");
136#endif
137 }
138
139 // Check for an existing constant operation for the attribute value.
140 Region *insertRegion = getInsertionRegion(interfaces, insertionBlock: opBlock);
141 auto &uniquedConstants = foldScopes[insertRegion];
142 Operation *&folderConstOp = uniquedConstants[std::make_tuple(
143 args: op->getDialect(), args&: constValue, args: *op->result_type_begin())];
144
145 // If there is an existing constant, replace `op`.
146 if (folderConstOp) {
147 notifyRemoval(op);
148 rewriter.replaceOp(op, newValues: folderConstOp->getResults());
149 folderConstOp->setLoc(erasedFoldedLocation);
150 return false;
151 }
152
153 // Otherwise, we insert `op`. If `op` is in the insertion block and is either
154 // already at the front of the block, or the previous operation is already a
155 // constant we unique'd (i.e. one we inserted), then we don't need to do
156 // anything. Otherwise, we move the constant to the insertion block.
157 Block *insertBlock = &insertRegion->front();
158 if (opBlock != insertBlock || (&insertBlock->front() != op &&
159 !isFolderOwnedConstant(op: op->getPrevNode()))) {
160 op->moveBefore(existingOp: &insertBlock->front());
161 op->setLoc(erasedFoldedLocation);
162 }
163
164 folderConstOp = op;
165 referencedDialects[op].push_back(Elt: op->getDialect());
166 return true;
167}
168
169/// Notifies that the given constant `op` should be remove from this
170/// OperationFolder's internal bookkeeping.
171void OperationFolder::notifyRemoval(Operation *op) {
172 // Check to see if this operation is uniqued within the folder.
173 auto it = referencedDialects.find(Val: op);
174 if (it == referencedDialects.end())
175 return;
176
177 // Get the constant value for this operation, this is the value that was used
178 // to unique the operation internally.
179 Attribute constValue;
180 matchPattern(op, pattern: m_Constant(bind_value: &constValue));
181 assert(constValue);
182
183 // Get the constant map that this operation was uniqued in.
184 auto &uniquedConstants =
185 foldScopes[getInsertionRegion(interfaces, insertionBlock: op->getBlock())];
186
187 // Erase all of the references to this operation.
188 auto type = op->getResult(idx: 0).getType();
189 for (auto *dialect : it->second)
190 uniquedConstants.erase(Val: std::make_tuple(args&: dialect, args&: constValue, args&: type));
191 referencedDialects.erase(I: it);
192}
193
194/// Clear out any constants cached inside of the folder.
195void OperationFolder::clear() {
196 foldScopes.clear();
197 referencedDialects.clear();
198}
199
200/// Get or create a constant using the given builder. On success this returns
201/// the constant operation, nullptr otherwise.
202Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
203 Attribute value, Type type) {
204 // Find an insertion point for the constant.
205 auto *insertRegion = getInsertionRegion(interfaces, insertionBlock: block);
206 auto &entry = insertRegion->front();
207 rewriter.setInsertionPoint(block: &entry, insertPoint: entry.begin());
208
209 // Get the constant map for the insertion region of this operation.
210 // Use erased location since the op is being built at the front of block.
211 auto &uniquedConstants = foldScopes[insertRegion];
212 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
213 type, erasedFoldedLocation);
214 return constOp ? constOp->getResult(idx: 0) : Value();
215}
216
217bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
218 return referencedDialects.count(Val: op);
219}
220
221/// Tries to perform folding on the given `op`. If successful, populates
222/// `results` with the results of the folding.
223LogicalResult OperationFolder::tryToFold(Operation *op,
224 SmallVectorImpl<Value> &results) {
225 SmallVector<OpFoldResult, 8> foldResults;
226 if (failed(result: op->fold(results&: foldResults)) ||
227 failed(result: processFoldResults(op, results, foldResults)))
228 return failure();
229 return success();
230}
231
232LogicalResult
233OperationFolder::processFoldResults(Operation *op,
234 SmallVectorImpl<Value> &results,
235 ArrayRef<OpFoldResult> foldResults) {
236 // Check to see if the operation was just updated in place.
237 if (foldResults.empty())
238 return success();
239 assert(foldResults.size() == op->getNumResults());
240
241 // Create a builder to insert new operations into the entry block of the
242 // insertion region.
243 auto *insertRegion = getInsertionRegion(interfaces, insertionBlock: op->getBlock());
244 auto &entry = insertRegion->front();
245 rewriter.setInsertionPoint(block: &entry, insertPoint: entry.begin());
246
247 // Get the constant map for the insertion region of this operation.
248 auto &uniquedConstants = foldScopes[insertRegion];
249
250 // Create the result constants and replace the results.
251 auto *dialect = op->getDialect();
252 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
253 assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
254
255 // Check if the result was an SSA value.
256 if (auto repl = llvm::dyn_cast_if_present<Value>(Val: foldResults[i])) {
257 results.emplace_back(Args&: repl);
258 continue;
259 }
260
261 // Check to see if there is a canonicalized version of this constant.
262 auto res = op->getResult(idx: i);
263 Attribute attrRepl = foldResults[i].get<Attribute>();
264 if (auto *constOp =
265 tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
266 res.getType(), erasedFoldedLocation)) {
267 // Ensure that this constant dominates the operation we are replacing it
268 // with. This may not automatically happen if the operation being folded
269 // was inserted before the constant within the insertion block.
270 Block *opBlock = op->getBlock();
271 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
272 constOp->moveBefore(&opBlock->front());
273
274 results.push_back(Elt: constOp->getResult(0));
275 continue;
276 }
277 // If materialization fails, cleanup any operations generated for the
278 // previous results and return failure.
279 for (Operation &op : llvm::make_early_inc_range(
280 Range: llvm::make_range(x: entry.begin(), y: rewriter.getInsertionPoint()))) {
281 notifyRemoval(op: &op);
282 rewriter.eraseOp(op: &op);
283 }
284
285 results.clear();
286 return failure();
287 }
288
289 return success();
290}
291
292/// Try to get or create a new constant entry. On success this returns the
293/// constant operation value, nullptr otherwise.
294Operation *
295OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
296 Dialect *dialect, Attribute value,
297 Type type, Location loc) {
298 // Check if an existing mapping already exists.
299 auto constKey = std::make_tuple(args&: dialect, args&: value, args&: type);
300 Operation *&constOp = uniquedConstants[constKey];
301 if (constOp) {
302 if (loc != constOp->getLoc())
303 constOp->setLoc(erasedFoldedLocation);
304 return constOp;
305 }
306
307 // If one doesn't exist, try to materialize one.
308 if (!(constOp = materializeConstant(dialect, builder&: rewriter, value, type, loc)))
309 return nullptr;
310
311 // Check to see if the generated constant is in the expected dialect.
312 auto *newDialect = constOp->getDialect();
313 if (newDialect == dialect) {
314 referencedDialects[constOp].push_back(Elt: dialect);
315 return constOp;
316 }
317
318 // If it isn't, then we also need to make sure that the mapping for the new
319 // dialect is valid.
320 auto newKey = std::make_tuple(args&: newDialect, args&: value, args&: type);
321
322 // If an existing operation in the new dialect already exists, delete the
323 // materialized operation in favor of the existing one.
324 if (auto *existingOp = uniquedConstants.lookup(Val: newKey)) {
325 notifyRemoval(op: constOp);
326 rewriter.eraseOp(op: constOp);
327 referencedDialects[existingOp].push_back(Elt: dialect);
328 if (loc != existingOp->getLoc())
329 existingOp->setLoc(erasedFoldedLocation);
330 return constOp = existingOp;
331 }
332
333 // Otherwise, update the new dialect to the materialized operation.
334 referencedDialects[constOp].assign(IL: {dialect, newDialect});
335 auto newIt = uniquedConstants.insert(KV: {newKey, constOp});
336 return newIt.first->second;
337}
338

source code of mlir/lib/Transforms/Utils/FoldUtils.cpp