1 | //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/IR/Utils.h" |
10 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
11 | #include "mlir/IR/Verifier.h" |
12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
13 | #include "llvm/Support/Debug.h" |
14 | |
15 | using namespace mlir; |
16 | |
17 | #define DEBUG_TYPE "transform-dialect-utils" |
18 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
19 | |
20 | /// Return whether `func1` can be merged into `func2`. For that to work |
21 | /// `func1` has to be a declaration (aka has to be external) and `func2` |
22 | /// either has to be a declaration as well, or it has to be public (otherwise, |
23 | /// it wouldn't be visible by `func1`). |
24 | static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { |
25 | return func1.isExternal() && (func2.isPublic() || func2.isExternal()); |
26 | } |
27 | |
28 | /// Merge `func1` into `func2`. The two ops must be inside the same parent op |
29 | /// and mergable according to `canMergeInto`. The function erases `func1` such |
30 | /// that only `func2` exists when the function returns. |
31 | static InFlightDiagnostic mergeInto(FunctionOpInterface func1, |
32 | FunctionOpInterface func2) { |
33 | assert(canMergeInto(func1, func2)); |
34 | assert(func1->getParentOp() == func2->getParentOp() && |
35 | "expected func1 and func2 to be in the same parent op" ); |
36 | |
37 | // Check that function signatures match. |
38 | if (func1.getFunctionType() != func2.getFunctionType()) { |
39 | return func1.emitError() |
40 | << "external definition has a mismatching signature (" |
41 | << func2.getFunctionType() << ")" ; |
42 | } |
43 | |
44 | // Check and merge argument attributes. |
45 | MLIRContext *context = func1->getContext(); |
46 | auto *td = context->getLoadedDialect<transform::TransformDialect>(); |
47 | StringAttr consumedName = td->getConsumedAttrName(); |
48 | StringAttr readOnlyName = td->getReadOnlyAttrName(); |
49 | for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { |
50 | bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; |
51 | bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; |
52 | bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; |
53 | bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; |
54 | if (!isExternalConsumed && !isExternalReadonly) { |
55 | if (isConsumed) |
56 | func2.setArgAttr(i, consumedName, UnitAttr::get(context)); |
57 | else if (isReadonly) |
58 | func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); |
59 | continue; |
60 | } |
61 | |
62 | if ((isExternalConsumed && !isConsumed) || |
63 | (isExternalReadonly && !isReadonly)) { |
64 | return func1.emitError() |
65 | << "external definition has mismatching consumption " |
66 | "annotations for argument #" |
67 | << i; |
68 | } |
69 | } |
70 | |
71 | // `func1` is the external one, so we can remove it. |
72 | assert(func1.isExternal()); |
73 | func1->erase(); |
74 | |
75 | return InFlightDiagnostic(); |
76 | } |
77 | |
78 | InFlightDiagnostic |
79 | transform::detail::mergeSymbolsInto(Operation *target, |
80 | OwningOpRef<Operation *> other) { |
81 | assert(target->hasTrait<OpTrait::SymbolTable>() && |
82 | "requires target to implement the 'SymbolTable' trait" ); |
83 | assert(other->hasTrait<OpTrait::SymbolTable>() && |
84 | "requires target to implement the 'SymbolTable' trait" ); |
85 | |
86 | SymbolTable targetSymbolTable(target); |
87 | SymbolTable otherSymbolTable(*other); |
88 | |
89 | // Step 1: |
90 | // |
91 | // Rename private symbols in both ops in order to resolve conflicts that can |
92 | // be resolved that way. |
93 | LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n" ); |
94 | // TODO: Do we *actually* need to test in both directions? |
95 | for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( |
96 | t: SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable}, |
97 | u: SmallVector<SymbolTable *, 2>{&otherSymbolTable, |
98 | &targetSymbolTable})) { |
99 | Operation *symbolTableOp = symbolTable->getOp(); |
100 | for (Operation &op : symbolTableOp->getRegion(index: 0).front()) { |
101 | auto symbolOp = dyn_cast<SymbolOpInterface>(op); |
102 | if (!symbolOp) |
103 | continue; |
104 | StringAttr name = symbolOp.getNameAttr(); |
105 | LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n" ); |
106 | |
107 | // Check if there is a colliding op in the other module. |
108 | auto collidingOp = |
109 | cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name)); |
110 | if (!collidingOp) |
111 | continue; |
112 | |
113 | LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); |
114 | |
115 | // Collisions are fine if both opt are functions and can be merged. |
116 | if (auto funcOp = dyn_cast<FunctionOpInterface>(op), |
117 | collidingFuncOp = |
118 | dyn_cast<FunctionOpInterface>(collidingOp.getOperation()); |
119 | funcOp && collidingFuncOp) { |
120 | if (canMergeInto(funcOp, collidingFuncOp) || |
121 | canMergeInto(collidingFuncOp, funcOp)) { |
122 | LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " |
123 | "will be merged\n" ); |
124 | continue; |
125 | } |
126 | |
127 | // If they can't be merged, proceed like any other collision. |
128 | LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions" ); |
129 | } |
130 | |
131 | // Collision can be resolved by renaming if one of the ops is private. |
132 | auto renameToUnique = |
133 | [&](SymbolOpInterface op, SymbolOpInterface otherOp, |
134 | SymbolTable &symbolTable, |
135 | SymbolTable &otherSymbolTable) -> InFlightDiagnostic { |
136 | LLVM_DEBUG(llvm::dbgs() << ", renaming\n" ); |
137 | FailureOr<StringAttr> maybeNewName = |
138 | symbolTable.renameToUnique(op, {&otherSymbolTable}); |
139 | if (failed(maybeNewName)) { |
140 | InFlightDiagnostic diag = op->emitError("failed to rename symbol" ); |
141 | diag.attachNote(noteLoc: otherOp->getLoc()) |
142 | << "attempted renaming due to collision with this op" ; |
143 | return diag; |
144 | } |
145 | LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() |
146 | << "\n" ); |
147 | return InFlightDiagnostic(); |
148 | }; |
149 | |
150 | if (symbolOp.isPrivate()) { |
151 | InFlightDiagnostic diag = renameToUnique( |
152 | symbolOp, collidingOp, *symbolTable, *otherSymbolTable); |
153 | if (failed(result: diag)) |
154 | return diag; |
155 | continue; |
156 | } |
157 | if (collidingOp.isPrivate()) { |
158 | InFlightDiagnostic diag = renameToUnique( |
159 | collidingOp, symbolOp, *otherSymbolTable, *symbolTable); |
160 | if (failed(result: diag)) |
161 | return diag; |
162 | continue; |
163 | } |
164 | LLVM_DEBUG(llvm::dbgs() << ", emitting error\n" ); |
165 | InFlightDiagnostic diag = symbolOp.emitError() |
166 | << "doubly defined symbol @" << name.getValue(); |
167 | diag.attachNote(noteLoc: collidingOp->getLoc()) << "previously defined here" ; |
168 | return diag; |
169 | } |
170 | } |
171 | |
172 | // TODO: This duplicates pass infrastructure. We should split this pass into |
173 | // several and let the pass infrastructure do the verification. |
174 | for (auto *op : SmallVector<Operation *>{target, *other}) { |
175 | if (failed(result: mlir::verify(op))) |
176 | return op->emitError() << "failed to verify input op after renaming" ; |
177 | } |
178 | |
179 | // Step 2: |
180 | // |
181 | // Move all ops from `other` into target and merge public symbols. |
182 | LLVM_DEBUG(DBGS() << "moving all symbols into target\n" ); |
183 | { |
184 | SmallVector<SymbolOpInterface> opsToMove; |
185 | for (Operation &op : other->getRegion(index: 0).front()) { |
186 | if (auto symbol = dyn_cast<SymbolOpInterface>(op)) |
187 | opsToMove.push_back(symbol); |
188 | } |
189 | |
190 | for (SymbolOpInterface op : opsToMove) { |
191 | // Remember potentially colliding op in the target module. |
192 | auto collidingOp = cast_or_null<SymbolOpInterface>( |
193 | targetSymbolTable.lookup(op.getNameAttr())); |
194 | |
195 | // Move op even if we get a collision. |
196 | LLVM_DEBUG(DBGS() << " moving @" << op.getName()); |
197 | op->moveBefore(&target->getRegion(0).front(), |
198 | target->getRegion(0).front().end()); |
199 | |
200 | // If there is no collision, we are done. |
201 | if (!collidingOp) { |
202 | LLVM_DEBUG(llvm::dbgs() << " without collision\n" ); |
203 | continue; |
204 | } |
205 | |
206 | // The two colliding ops must both be functions because we have already |
207 | // emitted errors otherwise earlier. |
208 | auto funcOp = cast<FunctionOpInterface>(op.getOperation()); |
209 | auto collidingFuncOp = |
210 | cast<FunctionOpInterface>(collidingOp.getOperation()); |
211 | |
212 | // Both ops are in the target module now and can be treated |
213 | // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into |
214 | // `collidingFuncOp`. |
215 | if (!canMergeInto(funcOp, collidingFuncOp)) { |
216 | std::swap(funcOp, collidingFuncOp); |
217 | } |
218 | assert(canMergeInto(funcOp, collidingFuncOp)); |
219 | |
220 | LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " |
221 | << collidingFuncOp.getLoc() << ":\n" |
222 | << collidingFuncOp << "\n" ); |
223 | |
224 | // Update symbol table. This works with or without the previous `swap`. |
225 | targetSymbolTable.remove(funcOp); |
226 | targetSymbolTable.insert(collidingFuncOp); |
227 | assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); |
228 | |
229 | // Do the actual merging. |
230 | { |
231 | InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp); |
232 | if (failed(diag)) |
233 | return diag; |
234 | } |
235 | } |
236 | } |
237 | |
238 | if (failed(result: mlir::verify(op: target))) |
239 | return target->emitError() |
240 | << "failed to verify target op after merging symbols" ; |
241 | |
242 | LLVM_DEBUG(DBGS() << "done merging ops\n" ); |
243 | return InFlightDiagnostic(); |
244 | } |
245 | |