| 1 | //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines various operation fold utilities. These utilities are |
// intended to be used by passes to unify and simplify their logic.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Transforms/FoldUtils.h" |
| 15 | |
| 16 | #include "mlir/IR/Builders.h" |
| 17 | #include "mlir/IR/Matchers.h" |
| 18 | #include "mlir/IR/Operation.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | |
| 22 | /// Given an operation, find the parent region that folded constants should be |
| 23 | /// inserted into. |
| 24 | static Region * |
| 25 | getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, |
| 26 | Block *insertionBlock) { |
| 27 | while (Region *region = insertionBlock->getParent()) { |
| 28 | // Insert in this region for any of the following scenarios: |
| 29 | // * The parent is unregistered, or is known to be isolated from above. |
| 30 | // * The parent is a top-level operation. |
| 31 | auto *parentOp = region->getParentOp(); |
| 32 | if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || |
| 33 | !parentOp->getBlock()) |
| 34 | return region; |
| 35 | |
| 36 | // Otherwise, check if this region is a desired insertion region. |
| 37 | auto *interface = interfaces.getInterfaceFor(obj: parentOp); |
| 38 | if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) |
| 39 | return region; |
| 40 | |
| 41 | // Traverse up the parent looking for an insertion region. |
| 42 | insertionBlock = parentOp->getBlock(); |
| 43 | } |
| 44 | llvm_unreachable("expected valid insertion region" ); |
| 45 | } |
| 46 | |
| 47 | /// A utility function used to materialize a constant for a given attribute and |
| 48 | /// type. On success, a valid constant value is returned. Otherwise, null is |
| 49 | /// returned |
| 50 | static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, |
| 51 | Attribute value, Type type, |
| 52 | Location loc) { |
| 53 | auto insertPt = builder.getInsertionPoint(); |
| 54 | (void)insertPt; |
| 55 | |
| 56 | // Ask the dialect to materialize a constant operation for this value. |
| 57 | if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { |
| 58 | assert(insertPt == builder.getInsertionPoint()); |
| 59 | assert(matchPattern(constOp, m_Constant())); |
| 60 | return constOp; |
| 61 | } |
| 62 | |
| 63 | return nullptr; |
| 64 | } |
| 65 | |
| 66 | //===----------------------------------------------------------------------===// |
| 67 | // OperationFolder |
| 68 | //===----------------------------------------------------------------------===// |
| 69 | |
| 70 | LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) { |
| 71 | if (inPlaceUpdate) |
| 72 | *inPlaceUpdate = false; |
| 73 | |
| 74 | // If this is a unique'd constant, return failure as we know that it has |
| 75 | // already been folded. |
| 76 | if (isFolderOwnedConstant(op)) { |
| 77 | // Check to see if we should rehoist, i.e. if a non-constant operation was |
| 78 | // inserted before this one. |
| 79 | Block *opBlock = op->getBlock(); |
| 80 | if (&opBlock->front() != op && !isFolderOwnedConstant(op: op->getPrevNode())) { |
| 81 | op->moveBefore(existingOp: &opBlock->front()); |
| 82 | op->setLoc(erasedFoldedLocation); |
| 83 | } |
| 84 | return failure(); |
| 85 | } |
| 86 | |
| 87 | // Try to fold the operation. |
| 88 | SmallVector<Value, 8> results; |
| 89 | if (failed(Result: tryToFold(op, results))) |
| 90 | return failure(); |
| 91 | |
| 92 | // Check to see if the operation was just updated in place. |
| 93 | if (results.empty()) { |
| 94 | if (inPlaceUpdate) |
| 95 | *inPlaceUpdate = true; |
| 96 | if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>( |
| 97 | Val: rewriter.getListener())) { |
| 98 | // Folding API does not notify listeners, so we have to notify manually. |
| 99 | rewriteListener->notifyOperationModified(op); |
| 100 | } |
| 101 | return success(); |
| 102 | } |
| 103 | |
| 104 | // Constant folding succeeded. Replace all of the result values and erase the |
| 105 | // operation. |
| 106 | notifyRemoval(op); |
| 107 | rewriter.replaceOp(op, newValues: results); |
| 108 | return success(); |
| 109 | } |
| 110 | |
| 111 | bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { |
| 112 | Block *opBlock = op->getBlock(); |
| 113 | |
| 114 | // If this is a constant we unique'd, we don't need to insert, but we can |
| 115 | // check to see if we should rehoist it. |
| 116 | if (isFolderOwnedConstant(op)) { |
| 117 | if (&opBlock->front() != op && !isFolderOwnedConstant(op: op->getPrevNode())) { |
| 118 | op->moveBefore(existingOp: &opBlock->front()); |
| 119 | op->setLoc(erasedFoldedLocation); |
| 120 | } |
| 121 | return true; |
| 122 | } |
| 123 | |
| 124 | // Get the constant value of the op if necessary. |
| 125 | if (!constValue) { |
| 126 | matchPattern(op, pattern: m_Constant(bind_value: &constValue)); |
| 127 | assert(constValue && "expected `op` to be a constant" ); |
| 128 | } else { |
| 129 | // Ensure that the provided constant was actually correct. |
| 130 | #ifndef NDEBUG |
| 131 | Attribute expectedValue; |
| 132 | matchPattern(op, pattern: m_Constant(bind_value: &expectedValue)); |
| 133 | assert( |
| 134 | expectedValue == constValue && |
| 135 | "provided constant value was not the expected value of the constant" ); |
| 136 | #endif |
| 137 | } |
| 138 | |
| 139 | // Check for an existing constant operation for the attribute value. |
| 140 | Region *insertRegion = getInsertionRegion(interfaces, insertionBlock: opBlock); |
| 141 | auto &uniquedConstants = foldScopes[insertRegion]; |
| 142 | Operation *&folderConstOp = uniquedConstants[std::make_tuple( |
| 143 | args: op->getDialect(), args&: constValue, args: *op->result_type_begin())]; |
| 144 | |
| 145 | // If there is an existing constant, replace `op`. |
| 146 | if (folderConstOp) { |
| 147 | notifyRemoval(op); |
| 148 | rewriter.replaceOp(op, newValues: folderConstOp->getResults()); |
| 149 | folderConstOp->setLoc(erasedFoldedLocation); |
| 150 | return false; |
| 151 | } |
| 152 | |
| 153 | // Otherwise, we insert `op`. If `op` is in the insertion block and is either |
| 154 | // already at the front of the block, or the previous operation is already a |
| 155 | // constant we unique'd (i.e. one we inserted), then we don't need to do |
| 156 | // anything. Otherwise, we move the constant to the insertion block. |
| 157 | Block *insertBlock = &insertRegion->front(); |
| 158 | if (opBlock != insertBlock || (&insertBlock->front() != op && |
| 159 | !isFolderOwnedConstant(op: op->getPrevNode()))) { |
| 160 | op->moveBefore(existingOp: &insertBlock->front()); |
| 161 | op->setLoc(erasedFoldedLocation); |
| 162 | } |
| 163 | |
| 164 | folderConstOp = op; |
| 165 | referencedDialects[op].push_back(Elt: op->getDialect()); |
| 166 | return true; |
| 167 | } |
| 168 | |
| 169 | /// Notifies that the given constant `op` should be remove from this |
| 170 | /// OperationFolder's internal bookkeeping. |
| 171 | void OperationFolder::notifyRemoval(Operation *op) { |
| 172 | // Check to see if this operation is uniqued within the folder. |
| 173 | auto it = referencedDialects.find(Val: op); |
| 174 | if (it == referencedDialects.end()) |
| 175 | return; |
| 176 | |
| 177 | // Get the constant value for this operation, this is the value that was used |
| 178 | // to unique the operation internally. |
| 179 | Attribute constValue; |
| 180 | matchPattern(op, pattern: m_Constant(bind_value: &constValue)); |
| 181 | assert(constValue); |
| 182 | |
| 183 | // Get the constant map that this operation was uniqued in. |
| 184 | auto &uniquedConstants = |
| 185 | foldScopes[getInsertionRegion(interfaces, insertionBlock: op->getBlock())]; |
| 186 | |
| 187 | // Erase all of the references to this operation. |
| 188 | auto type = op->getResult(idx: 0).getType(); |
| 189 | for (auto *dialect : it->second) |
| 190 | uniquedConstants.erase(Val: std::make_tuple(args&: dialect, args&: constValue, args&: type)); |
| 191 | referencedDialects.erase(I: it); |
| 192 | } |
| 193 | |
| 194 | /// Clear out any constants cached inside of the folder. |
| 195 | void OperationFolder::clear() { |
| 196 | foldScopes.clear(); |
| 197 | referencedDialects.clear(); |
| 198 | } |
| 199 | |
| 200 | /// Get or create a constant using the given builder. On success this returns |
| 201 | /// the constant operation, nullptr otherwise. |
| 202 | Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect, |
| 203 | Attribute value, Type type) { |
| 204 | // Find an insertion point for the constant. |
| 205 | auto *insertRegion = getInsertionRegion(interfaces, insertionBlock: block); |
| 206 | auto &entry = insertRegion->front(); |
| 207 | rewriter.setInsertionPointToStart(&entry); |
| 208 | |
| 209 | // Get the constant map for the insertion region of this operation. |
| 210 | // Use erased location since the op is being built at the front of block. |
| 211 | auto &uniquedConstants = foldScopes[insertRegion]; |
| 212 | Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value, |
| 213 | type, erasedFoldedLocation); |
| 214 | return constOp ? constOp->getResult(idx: 0) : Value(); |
| 215 | } |
| 216 | |
| 217 | bool OperationFolder::isFolderOwnedConstant(Operation *op) const { |
| 218 | return referencedDialects.count(Val: op); |
| 219 | } |
| 220 | |
| 221 | /// Tries to perform folding on the given `op`. If successful, populates |
| 222 | /// `results` with the results of the folding. |
| 223 | LogicalResult OperationFolder::tryToFold(Operation *op, |
| 224 | SmallVectorImpl<Value> &results) { |
| 225 | SmallVector<OpFoldResult, 8> foldResults; |
| 226 | if (failed(Result: op->fold(results&: foldResults)) || |
| 227 | failed(Result: processFoldResults(op, results, foldResults))) |
| 228 | return failure(); |
| 229 | return success(); |
| 230 | } |
| 231 | |
| 232 | LogicalResult |
| 233 | OperationFolder::processFoldResults(Operation *op, |
| 234 | SmallVectorImpl<Value> &results, |
| 235 | ArrayRef<OpFoldResult> foldResults) { |
| 236 | // Check to see if the operation was just updated in place. |
| 237 | if (foldResults.empty()) |
| 238 | return success(); |
| 239 | assert(foldResults.size() == op->getNumResults()); |
| 240 | |
| 241 | // Create a builder to insert new operations into the entry block of the |
| 242 | // insertion region. |
| 243 | auto *insertRegion = getInsertionRegion(interfaces, insertionBlock: op->getBlock()); |
| 244 | auto &entry = insertRegion->front(); |
| 245 | rewriter.setInsertionPointToStart(&entry); |
| 246 | |
| 247 | // Get the constant map for the insertion region of this operation. |
| 248 | auto &uniquedConstants = foldScopes[insertRegion]; |
| 249 | |
| 250 | // Create the result constants and replace the results. |
| 251 | auto *dialect = op->getDialect(); |
| 252 | for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { |
| 253 | assert(!foldResults[i].isNull() && "expected valid OpFoldResult" ); |
| 254 | |
| 255 | // Check if the result was an SSA value. |
| 256 | if (auto repl = llvm::dyn_cast_if_present<Value>(Val: foldResults[i])) { |
| 257 | results.emplace_back(Args&: repl); |
| 258 | continue; |
| 259 | } |
| 260 | |
| 261 | // Check to see if there is a canonicalized version of this constant. |
| 262 | auto res = op->getResult(idx: i); |
| 263 | Attribute attrRepl = cast<Attribute>(Val: foldResults[i]); |
| 264 | if (auto *constOp = |
| 265 | tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl, |
| 266 | res.getType(), erasedFoldedLocation)) { |
| 267 | // Ensure that this constant dominates the operation we are replacing it |
| 268 | // with. This may not automatically happen if the operation being folded |
| 269 | // was inserted before the constant within the insertion block. |
| 270 | Block *opBlock = op->getBlock(); |
| 271 | if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) |
| 272 | constOp->moveBefore(&opBlock->front()); |
| 273 | |
| 274 | results.push_back(Elt: constOp->getResult(0)); |
| 275 | continue; |
| 276 | } |
| 277 | // If materialization fails, cleanup any operations generated for the |
| 278 | // previous results and return failure. |
| 279 | for (Operation &op : llvm::make_early_inc_range( |
| 280 | Range: llvm::make_range(x: entry.begin(), y: rewriter.getInsertionPoint()))) { |
| 281 | notifyRemoval(op: &op); |
| 282 | rewriter.eraseOp(op: &op); |
| 283 | } |
| 284 | |
| 285 | results.clear(); |
| 286 | return failure(); |
| 287 | } |
| 288 | |
| 289 | return success(); |
| 290 | } |
| 291 | |
| 292 | /// Try to get or create a new constant entry. On success this returns the |
| 293 | /// constant operation value, nullptr otherwise. |
| 294 | Operation * |
| 295 | OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants, |
| 296 | Dialect *dialect, Attribute value, |
| 297 | Type type, Location loc) { |
| 298 | // Check if an existing mapping already exists. |
| 299 | auto constKey = std::make_tuple(args&: dialect, args&: value, args&: type); |
| 300 | Operation *&constOp = uniquedConstants[constKey]; |
| 301 | if (constOp) { |
| 302 | if (loc != constOp->getLoc()) |
| 303 | constOp->setLoc(erasedFoldedLocation); |
| 304 | return constOp; |
| 305 | } |
| 306 | |
| 307 | // If one doesn't exist, try to materialize one. |
| 308 | if (!(constOp = materializeConstant(dialect, builder&: rewriter, value, type, loc))) |
| 309 | return nullptr; |
| 310 | |
| 311 | // Check to see if the generated constant is in the expected dialect. |
| 312 | auto *newDialect = constOp->getDialect(); |
| 313 | if (newDialect == dialect) { |
| 314 | referencedDialects[constOp].push_back(Elt: dialect); |
| 315 | return constOp; |
| 316 | } |
| 317 | |
| 318 | // If it isn't, then we also need to make sure that the mapping for the new |
| 319 | // dialect is valid. |
| 320 | auto newKey = std::make_tuple(args&: newDialect, args&: value, args&: type); |
| 321 | |
| 322 | // If an existing operation in the new dialect already exists, delete the |
| 323 | // materialized operation in favor of the existing one. |
| 324 | if (auto *existingOp = uniquedConstants.lookup(Val: newKey)) { |
| 325 | notifyRemoval(op: constOp); |
| 326 | rewriter.eraseOp(op: constOp); |
| 327 | referencedDialects[existingOp].push_back(Elt: dialect); |
| 328 | if (loc != existingOp->getLoc()) |
| 329 | existingOp->setLoc(erasedFoldedLocation); |
| 330 | return constOp = existingOp; |
| 331 | } |
| 332 | |
| 333 | // Otherwise, update the new dialect to the materialized operation. |
| 334 | referencedDialects[constOp].assign(IL: {dialect, newDialect}); |
| 335 | auto newIt = uniquedConstants.insert(KV: {newKey, constOp}); |
| 336 | return newIt.first->second; |
| 337 | } |
| 338 | |