| 1 | //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Transform/IR/Utils.h" |
| 10 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 11 | #include "mlir/IR/Verifier.h" |
| 12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 13 | #include "llvm/Support/Debug.h" |
| 14 | |
| 15 | using namespace mlir; |
| 16 | |
| 17 | #define DEBUG_TYPE "transform-dialect-utils" |
| 18 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
| 19 | |
| 20 | /// Return whether `func1` can be merged into `func2`. For that to work |
| 21 | /// `func1` has to be a declaration (aka has to be external) and `func2` |
| 22 | /// either has to be a declaration as well, or it has to be public (otherwise, |
| 23 | /// it wouldn't be visible by `func1`). |
| 24 | static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { |
| 25 | return func1.isExternal() && (func2.isPublic() || func2.isExternal()); |
| 26 | } |
| 27 | |
| 28 | /// Merge `func1` into `func2`. The two ops must be inside the same parent op |
| 29 | /// and mergable according to `canMergeInto`. The function erases `func1` such |
| 30 | /// that only `func2` exists when the function returns. |
| 31 | static InFlightDiagnostic mergeInto(FunctionOpInterface func1, |
| 32 | FunctionOpInterface func2) { |
| 33 | assert(canMergeInto(func1, func2)); |
| 34 | assert(func1->getParentOp() == func2->getParentOp() && |
| 35 | "expected func1 and func2 to be in the same parent op" ); |
| 36 | |
| 37 | // Check that function signatures match. |
| 38 | if (func1.getFunctionType() != func2.getFunctionType()) { |
| 39 | return func1.emitError() |
| 40 | << "external definition has a mismatching signature (" |
| 41 | << func2.getFunctionType() << ")" ; |
| 42 | } |
| 43 | |
| 44 | // Check and merge argument attributes. |
| 45 | MLIRContext *context = func1->getContext(); |
| 46 | auto *td = context->getLoadedDialect<transform::TransformDialect>(); |
| 47 | StringAttr consumedName = td->getConsumedAttrName(); |
| 48 | StringAttr readOnlyName = td->getReadOnlyAttrName(); |
| 49 | for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { |
| 50 | bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; |
| 51 | bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; |
| 52 | bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; |
| 53 | bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; |
| 54 | if (!isExternalConsumed && !isExternalReadonly) { |
| 55 | if (isConsumed) |
| 56 | func2.setArgAttr(i, consumedName, UnitAttr::get(context)); |
| 57 | else if (isReadonly) |
| 58 | func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); |
| 59 | continue; |
| 60 | } |
| 61 | |
| 62 | if ((isExternalConsumed && !isConsumed) || |
| 63 | (isExternalReadonly && !isReadonly)) { |
| 64 | return func1.emitError() |
| 65 | << "external definition has mismatching consumption " |
| 66 | "annotations for argument #" |
| 67 | << i; |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | // `func1` is the external one, so we can remove it. |
| 72 | assert(func1.isExternal()); |
| 73 | func1->erase(); |
| 74 | |
| 75 | return InFlightDiagnostic(); |
| 76 | } |
| 77 | |
| 78 | InFlightDiagnostic |
| 79 | transform::detail::mergeSymbolsInto(Operation *target, |
| 80 | OwningOpRef<Operation *> other) { |
| 81 | assert(target->hasTrait<OpTrait::SymbolTable>() && |
| 82 | "requires target to implement the 'SymbolTable' trait" ); |
| 83 | assert(other->hasTrait<OpTrait::SymbolTable>() && |
| 84 | "requires target to implement the 'SymbolTable' trait" ); |
| 85 | |
| 86 | SymbolTable targetSymbolTable(target); |
| 87 | SymbolTable otherSymbolTable(*other); |
| 88 | |
| 89 | // Step 1: |
| 90 | // |
| 91 | // Rename private symbols in both ops in order to resolve conflicts that can |
| 92 | // be resolved that way. |
| 93 | LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n" ); |
| 94 | // TODO: Do we *actually* need to test in both directions? |
| 95 | for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( |
| 96 | t: SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable}, |
| 97 | u: SmallVector<SymbolTable *, 2>{&otherSymbolTable, |
| 98 | &targetSymbolTable})) { |
| 99 | Operation *symbolTableOp = symbolTable->getOp(); |
| 100 | for (Operation &op : symbolTableOp->getRegion(index: 0).front()) { |
| 101 | auto symbolOp = dyn_cast<SymbolOpInterface>(op); |
| 102 | if (!symbolOp) |
| 103 | continue; |
| 104 | StringAttr name = symbolOp.getNameAttr(); |
| 105 | LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n" ); |
| 106 | |
| 107 | // Check if there is a colliding op in the other module. |
| 108 | auto collidingOp = |
| 109 | cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name)); |
| 110 | if (!collidingOp) |
| 111 | continue; |
| 112 | |
| 113 | LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); |
| 114 | |
| 115 | // Collisions are fine if both opt are functions and can be merged. |
| 116 | if (auto funcOp = dyn_cast<FunctionOpInterface>(op), |
| 117 | collidingFuncOp = |
| 118 | dyn_cast<FunctionOpInterface>(collidingOp.getOperation()); |
| 119 | funcOp && collidingFuncOp) { |
| 120 | if (canMergeInto(funcOp, collidingFuncOp) || |
| 121 | canMergeInto(collidingFuncOp, funcOp)) { |
| 122 | LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " |
| 123 | "will be merged\n" ); |
| 124 | continue; |
| 125 | } |
| 126 | |
| 127 | // If they can't be merged, proceed like any other collision. |
| 128 | LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions" ); |
| 129 | } |
| 130 | |
| 131 | // Collision can be resolved by renaming if one of the ops is private. |
| 132 | auto renameToUnique = |
| 133 | [&](SymbolOpInterface op, SymbolOpInterface otherOp, |
| 134 | SymbolTable &symbolTable, |
| 135 | SymbolTable &otherSymbolTable) -> InFlightDiagnostic { |
| 136 | LLVM_DEBUG(llvm::dbgs() << ", renaming\n" ); |
| 137 | FailureOr<StringAttr> maybeNewName = |
| 138 | symbolTable.renameToUnique(op, {&otherSymbolTable}); |
| 139 | if (failed(maybeNewName)) { |
| 140 | InFlightDiagnostic diag = op->emitError("failed to rename symbol" ); |
| 141 | diag.attachNote(noteLoc: otherOp->getLoc()) |
| 142 | << "attempted renaming due to collision with this op" ; |
| 143 | return diag; |
| 144 | } |
| 145 | LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() |
| 146 | << "\n" ); |
| 147 | return InFlightDiagnostic(); |
| 148 | }; |
| 149 | |
| 150 | if (symbolOp.isPrivate()) { |
| 151 | InFlightDiagnostic diag = renameToUnique( |
| 152 | symbolOp, collidingOp, *symbolTable, *otherSymbolTable); |
| 153 | if (failed(Result: diag)) |
| 154 | return diag; |
| 155 | continue; |
| 156 | } |
| 157 | if (collidingOp.isPrivate()) { |
| 158 | InFlightDiagnostic diag = renameToUnique( |
| 159 | collidingOp, symbolOp, *otherSymbolTable, *symbolTable); |
| 160 | if (failed(Result: diag)) |
| 161 | return diag; |
| 162 | continue; |
| 163 | } |
| 164 | LLVM_DEBUG(llvm::dbgs() << ", emitting error\n" ); |
| 165 | InFlightDiagnostic diag = symbolOp.emitError() |
| 166 | << "doubly defined symbol @" << name.getValue(); |
| 167 | diag.attachNote(noteLoc: collidingOp->getLoc()) << "previously defined here" ; |
| 168 | return diag; |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | // TODO: This duplicates pass infrastructure. We should split this pass into |
| 173 | // several and let the pass infrastructure do the verification. |
| 174 | for (auto *op : SmallVector<Operation *>{target, *other}) { |
| 175 | if (failed(Result: mlir::verify(op))) |
| 176 | return op->emitError() << "failed to verify input op after renaming" ; |
| 177 | } |
| 178 | |
| 179 | // Step 2: |
| 180 | // |
| 181 | // Move all ops from `other` into target and merge public symbols. |
| 182 | LLVM_DEBUG(DBGS() << "moving all symbols into target\n" ); |
| 183 | { |
| 184 | SmallVector<SymbolOpInterface> opsToMove; |
| 185 | for (Operation &op : other->getRegion(index: 0).front()) { |
| 186 | if (auto symbol = dyn_cast<SymbolOpInterface>(op)) |
| 187 | opsToMove.push_back(symbol); |
| 188 | } |
| 189 | |
| 190 | for (SymbolOpInterface op : opsToMove) { |
| 191 | // Remember potentially colliding op in the target module. |
| 192 | auto collidingOp = cast_or_null<SymbolOpInterface>( |
| 193 | targetSymbolTable.lookup(op.getNameAttr())); |
| 194 | |
| 195 | // Move op even if we get a collision. |
| 196 | LLVM_DEBUG(DBGS() << " moving @" << op.getName()); |
| 197 | op->moveBefore(&target->getRegion(0).front(), |
| 198 | target->getRegion(0).front().end()); |
| 199 | |
| 200 | // If there is no collision, we are done. |
| 201 | if (!collidingOp) { |
| 202 | LLVM_DEBUG(llvm::dbgs() << " without collision\n" ); |
| 203 | continue; |
| 204 | } |
| 205 | |
| 206 | // The two colliding ops must both be functions because we have already |
| 207 | // emitted errors otherwise earlier. |
| 208 | auto funcOp = cast<FunctionOpInterface>(op.getOperation()); |
| 209 | auto collidingFuncOp = |
| 210 | cast<FunctionOpInterface>(collidingOp.getOperation()); |
| 211 | |
| 212 | // Both ops are in the target module now and can be treated |
| 213 | // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into |
| 214 | // `collidingFuncOp`. |
| 215 | if (!canMergeInto(funcOp, collidingFuncOp)) { |
| 216 | std::swap(funcOp, collidingFuncOp); |
| 217 | } |
| 218 | assert(canMergeInto(funcOp, collidingFuncOp)); |
| 219 | |
| 220 | LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " |
| 221 | << collidingFuncOp.getLoc() << ":\n" |
| 222 | << collidingFuncOp << "\n" ); |
| 223 | |
| 224 | // Update symbol table. This works with or without the previous `swap`. |
| 225 | targetSymbolTable.remove(funcOp); |
| 226 | targetSymbolTable.insert(collidingFuncOp); |
| 227 | assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); |
| 228 | |
| 229 | // Do the actual merging. |
| 230 | { |
| 231 | InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp); |
| 232 | if (failed(diag)) |
| 233 | return diag; |
| 234 | } |
| 235 | } |
| 236 | } |
| 237 | |
| 238 | if (failed(Result: mlir::verify(op: target))) |
| 239 | return target->emitError() |
| 240 | << "failed to verify target op after merging symbols" ; |
| 241 | |
| 242 | LLVM_DEBUG(DBGS() << "done merging ops\n" ); |
| 243 | return InFlightDiagnostic(); |
| 244 | } |
| 245 | |