1//===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/Utils.h"
10#include "mlir/Dialect/Transform/IR/TransformDialect.h"
11#include "mlir/IR/Verifier.h"
12#include "mlir/Interfaces/FunctionInterfaces.h"
13#include "llvm/Support/Debug.h"
14
15using namespace mlir;
16
17#define DEBUG_TYPE "transform-dialect-utils"
18#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19
20/// Return whether `func1` can be merged into `func2`. For that to work
21/// `func1` has to be a declaration (aka has to be external) and `func2`
22/// either has to be a declaration as well, or it has to be public (otherwise,
23/// it wouldn't be visible by `func1`).
24static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25 return func1.isExternal() && (func2.isPublic() || func2.isExternal());
26}
27
28/// Merge `func1` into `func2`. The two ops must be inside the same parent op
29/// and mergable according to `canMergeInto`. The function erases `func1` such
30/// that only `func2` exists when the function returns.
31static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
32 FunctionOpInterface func2) {
33 assert(canMergeInto(func1, func2));
34 assert(func1->getParentOp() == func2->getParentOp() &&
35 "expected func1 and func2 to be in the same parent op");
36
37 // Check that function signatures match.
38 if (func1.getFunctionType() != func2.getFunctionType()) {
39 return func1.emitError()
40 << "external definition has a mismatching signature ("
41 << func2.getFunctionType() << ")";
42 }
43
44 // Check and merge argument attributes.
45 MLIRContext *context = func1->getContext();
46 auto *td = context->getLoadedDialect<transform::TransformDialect>();
47 StringAttr consumedName = td->getConsumedAttrName();
48 StringAttr readOnlyName = td->getReadOnlyAttrName();
49 for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50 bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
51 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
52 bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
53 bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
54 if (!isExternalConsumed && !isExternalReadonly) {
55 if (isConsumed)
56 func2.setArgAttr(i, consumedName, UnitAttr::get(context));
57 else if (isReadonly)
58 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
59 continue;
60 }
61
62 if ((isExternalConsumed && !isConsumed) ||
63 (isExternalReadonly && !isReadonly)) {
64 return func1.emitError()
65 << "external definition has mismatching consumption "
66 "annotations for argument #"
67 << i;
68 }
69 }
70
71 // `func1` is the external one, so we can remove it.
72 assert(func1.isExternal());
73 func1->erase();
74
75 return InFlightDiagnostic();
76}
77
78InFlightDiagnostic
79transform::detail::mergeSymbolsInto(Operation *target,
80 OwningOpRef<Operation *> other) {
81 assert(target->hasTrait<OpTrait::SymbolTable>() &&
82 "requires target to implement the 'SymbolTable' trait");
83 assert(other->hasTrait<OpTrait::SymbolTable>() &&
84 "requires target to implement the 'SymbolTable' trait");
85
86 SymbolTable targetSymbolTable(target);
87 SymbolTable otherSymbolTable(*other);
88
89 // Step 1:
90 //
91 // Rename private symbols in both ops in order to resolve conflicts that can
92 // be resolved that way.
93 LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
94 // TODO: Do we *actually* need to test in both directions?
95 for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
96 t: SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
97 u: SmallVector<SymbolTable *, 2>{&otherSymbolTable,
98 &targetSymbolTable})) {
99 Operation *symbolTableOp = symbolTable->getOp();
100 for (Operation &op : symbolTableOp->getRegion(index: 0).front()) {
101 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
102 if (!symbolOp)
103 continue;
104 StringAttr name = symbolOp.getNameAttr();
105 LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
106
107 // Check if there is a colliding op in the other module.
108 auto collidingOp =
109 cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
110 if (!collidingOp)
111 continue;
112
113 LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
114
115 // Collisions are fine if both opt are functions and can be merged.
116 if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
117 collidingFuncOp =
118 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
119 funcOp && collidingFuncOp) {
120 if (canMergeInto(funcOp, collidingFuncOp) ||
121 canMergeInto(collidingFuncOp, funcOp)) {
122 LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
123 "will be merged\n");
124 continue;
125 }
126
127 // If they can't be merged, proceed like any other collision.
128 LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
129 }
130
131 // Collision can be resolved by renaming if one of the ops is private.
132 auto renameToUnique =
133 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
134 SymbolTable &symbolTable,
135 SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
136 LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
137 FailureOr<StringAttr> maybeNewName =
138 symbolTable.renameToUnique(op, {&otherSymbolTable});
139 if (failed(maybeNewName)) {
140 InFlightDiagnostic diag = op->emitError("failed to rename symbol");
141 diag.attachNote(noteLoc: otherOp->getLoc())
142 << "attempted renaming due to collision with this op";
143 return diag;
144 }
145 LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
146 << "\n");
147 return InFlightDiagnostic();
148 };
149
150 if (symbolOp.isPrivate()) {
151 InFlightDiagnostic diag = renameToUnique(
152 symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
153 if (failed(result: diag))
154 return diag;
155 continue;
156 }
157 if (collidingOp.isPrivate()) {
158 InFlightDiagnostic diag = renameToUnique(
159 collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
160 if (failed(result: diag))
161 return diag;
162 continue;
163 }
164 LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
165 InFlightDiagnostic diag = symbolOp.emitError()
166 << "doubly defined symbol @" << name.getValue();
167 diag.attachNote(noteLoc: collidingOp->getLoc()) << "previously defined here";
168 return diag;
169 }
170 }
171
172 // TODO: This duplicates pass infrastructure. We should split this pass into
173 // several and let the pass infrastructure do the verification.
174 for (auto *op : SmallVector<Operation *>{target, *other}) {
175 if (failed(result: mlir::verify(op)))
176 return op->emitError() << "failed to verify input op after renaming";
177 }
178
179 // Step 2:
180 //
181 // Move all ops from `other` into target and merge public symbols.
182 LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
183 {
184 SmallVector<SymbolOpInterface> opsToMove;
185 for (Operation &op : other->getRegion(index: 0).front()) {
186 if (auto symbol = dyn_cast<SymbolOpInterface>(op))
187 opsToMove.push_back(symbol);
188 }
189
190 for (SymbolOpInterface op : opsToMove) {
191 // Remember potentially colliding op in the target module.
192 auto collidingOp = cast_or_null<SymbolOpInterface>(
193 targetSymbolTable.lookup(op.getNameAttr()));
194
195 // Move op even if we get a collision.
196 LLVM_DEBUG(DBGS() << " moving @" << op.getName());
197 op->moveBefore(&target->getRegion(0).front(),
198 target->getRegion(0).front().end());
199
200 // If there is no collision, we are done.
201 if (!collidingOp) {
202 LLVM_DEBUG(llvm::dbgs() << " without collision\n");
203 continue;
204 }
205
206 // The two colliding ops must both be functions because we have already
207 // emitted errors otherwise earlier.
208 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
209 auto collidingFuncOp =
210 cast<FunctionOpInterface>(collidingOp.getOperation());
211
212 // Both ops are in the target module now and can be treated
213 // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
214 // `collidingFuncOp`.
215 if (!canMergeInto(funcOp, collidingFuncOp)) {
216 std::swap(funcOp, collidingFuncOp);
217 }
218 assert(canMergeInto(funcOp, collidingFuncOp));
219
220 LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
221 << collidingFuncOp.getLoc() << ":\n"
222 << collidingFuncOp << "\n");
223
224 // Update symbol table. This works with or without the previous `swap`.
225 targetSymbolTable.remove(funcOp);
226 targetSymbolTable.insert(collidingFuncOp);
227 assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
228
229 // Do the actual merging.
230 {
231 InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
232 if (failed(diag))
233 return diag;
234 }
235 }
236 }
237
238 if (failed(result: mlir::verify(op: target)))
239 return target->emitError()
240 << "failed to verify target op after merging symbols";
241
242 LLVM_DEBUG(DBGS() << "done merging ops\n");
243 return InFlightDiagnostic();
244}
245

source code of mlir/lib/Dialect/Transform/IR/Utils.cpp