1 | //===- Pattern.cpp - Pattern wrapper class --------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Pattern wrapper class that simplifies working with a TableGen Record that
// defines an MLIR Pattern.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include <utility> |
15 | |
16 | #include "mlir/TableGen/Pattern.h" |
17 | #include "llvm/ADT/StringExtras.h" |
18 | #include "llvm/ADT/Twine.h" |
19 | #include "llvm/Support/Debug.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "llvm/TableGen/Error.h" |
22 | #include "llvm/TableGen/Record.h" |
23 | |
24 | #define DEBUG_TYPE "mlir-tblgen-pattern" |
25 | |
26 | using namespace mlir; |
27 | using namespace tblgen; |
28 | |
29 | using llvm::formatv; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // DagLeaf |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | bool DagLeaf::isUnspecified() const { |
36 | return isa_and_nonnull<llvm::UnsetInit>(Val: def); |
37 | } |
38 | |
39 | bool DagLeaf::isOperandMatcher() const { |
40 | // Operand matchers specify a type constraint. |
41 | return isSubClassOf(superclass: "TypeConstraint" ); |
42 | } |
43 | |
44 | bool DagLeaf::isAttrMatcher() const { |
45 | // Attribute matchers specify an attribute constraint. |
46 | return isSubClassOf(superclass: "AttrConstraint" ); |
47 | } |
48 | |
49 | bool DagLeaf::isNativeCodeCall() const { |
50 | return isSubClassOf(superclass: "NativeCodeCall" ); |
51 | } |
52 | |
53 | bool DagLeaf::isConstantAttr() const { return isSubClassOf(superclass: "ConstantAttr" ); } |
54 | |
55 | bool DagLeaf::isEnumAttrCase() const { |
56 | return isSubClassOf(superclass: "EnumAttrCaseInfo" ); |
57 | } |
58 | |
59 | bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(Val: def); } |
60 | |
61 | Constraint DagLeaf::getAsConstraint() const { |
62 | assert((isOperandMatcher() || isAttrMatcher()) && |
63 | "the DAG leaf must be operand or attribute" ); |
64 | return Constraint(cast<llvm::DefInit>(Val: def)->getDef()); |
65 | } |
66 | |
67 | ConstantAttr DagLeaf::getAsConstantAttr() const { |
68 | assert(isConstantAttr() && "the DAG leaf must be constant attribute" ); |
69 | return ConstantAttr(cast<llvm::DefInit>(Val: def)); |
70 | } |
71 | |
72 | EnumAttrCase DagLeaf::getAsEnumAttrCase() const { |
73 | assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case" ); |
74 | return EnumAttrCase(cast<llvm::DefInit>(Val: def)); |
75 | } |
76 | |
77 | std::string DagLeaf::getConditionTemplate() const { |
78 | return getAsConstraint().getConditionTemplate(); |
79 | } |
80 | |
81 | llvm::StringRef DagLeaf::getNativeCodeTemplate() const { |
82 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall" ); |
83 | return cast<llvm::DefInit>(Val: def)->getDef()->getValueAsString(FieldName: "expression" ); |
84 | } |
85 | |
86 | int DagLeaf::getNumReturnsOfNativeCode() const { |
87 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall" ); |
88 | return cast<llvm::DefInit>(Val: def)->getDef()->getValueAsInt(FieldName: "numReturns" ); |
89 | } |
90 | |
91 | std::string DagLeaf::getStringAttr() const { |
92 | assert(isStringAttr() && "the DAG leaf must be string attribute" ); |
93 | return def->getAsUnquotedString(); |
94 | } |
95 | bool DagLeaf::isSubClassOf(StringRef superclass) const { |
96 | if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: def)) |
97 | return defInit->getDef()->isSubClassOf(Name: superclass); |
98 | return false; |
99 | } |
100 | |
101 | void DagLeaf::print(raw_ostream &os) const { |
102 | if (def) |
103 | def->print(OS&: os); |
104 | } |
105 | |
106 | //===----------------------------------------------------------------------===// |
107 | // DagNode |
108 | //===----------------------------------------------------------------------===// |
109 | |
110 | bool DagNode::isNativeCodeCall() const { |
111 | if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: node->getOperator())) |
112 | return defInit->getDef()->isSubClassOf(Name: "NativeCodeCall" ); |
113 | return false; |
114 | } |
115 | |
116 | bool DagNode::isOperation() const { |
117 | return !isNativeCodeCall() && !isReplaceWithValue() && |
118 | !isLocationDirective() && !isReturnTypeDirective() && !isEither() && |
119 | !isVariadic(); |
120 | } |
121 | |
122 | llvm::StringRef DagNode::getNativeCodeTemplate() const { |
123 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall" ); |
124 | return cast<llvm::DefInit>(Val: node->getOperator()) |
125 | ->getDef() |
126 | ->getValueAsString(FieldName: "expression" ); |
127 | } |
128 | |
129 | int DagNode::getNumReturnsOfNativeCode() const { |
130 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall" ); |
131 | return cast<llvm::DefInit>(Val: node->getOperator()) |
132 | ->getDef() |
133 | ->getValueAsInt(FieldName: "numReturns" ); |
134 | } |
135 | |
136 | llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } |
137 | |
138 | Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { |
139 | llvm::Record *opDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef(); |
140 | auto it = mapper->find(Val: opDef); |
141 | if (it != mapper->end()) |
142 | return *it->second; |
143 | return *mapper->try_emplace(Key: opDef, Args: std::make_unique<Operator>(args&: opDef)) |
144 | .first->second; |
145 | } |
146 | |
147 | int DagNode::getNumOps() const { |
148 | // We want to get number of operations recursively involved in the DAG tree. |
149 | // All other directives should be excluded. |
150 | int count = isOperation() ? 1 : 0; |
151 | for (int i = 0, e = getNumArgs(); i != e; ++i) { |
152 | if (auto child = getArgAsNestedDag(index: i)) |
153 | count += child.getNumOps(); |
154 | } |
155 | return count; |
156 | } |
157 | |
158 | int DagNode::getNumArgs() const { return node->getNumArgs(); } |
159 | |
160 | bool DagNode::isNestedDagArg(unsigned index) const { |
161 | return isa<llvm::DagInit>(Val: node->getArg(Num: index)); |
162 | } |
163 | |
164 | DagNode DagNode::getArgAsNestedDag(unsigned index) const { |
165 | return DagNode(dyn_cast_or_null<llvm::DagInit>(Val: node->getArg(Num: index))); |
166 | } |
167 | |
168 | DagLeaf DagNode::getArgAsLeaf(unsigned index) const { |
169 | assert(!isNestedDagArg(index)); |
170 | return DagLeaf(node->getArg(Num: index)); |
171 | } |
172 | |
173 | StringRef DagNode::getArgName(unsigned index) const { |
174 | return node->getArgNameStr(Num: index); |
175 | } |
176 | |
177 | bool DagNode::isReplaceWithValue() const { |
178 | auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef(); |
179 | return dagOpDef->getName() == "replaceWithValue" ; |
180 | } |
181 | |
182 | bool DagNode::isLocationDirective() const { |
183 | auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef(); |
184 | return dagOpDef->getName() == "location" ; |
185 | } |
186 | |
187 | bool DagNode::isReturnTypeDirective() const { |
188 | auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef(); |
189 | return dagOpDef->getName() == "returnType" ; |
190 | } |
191 | |
192 | bool DagNode::isEither() const { |
193 | auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef(); |
194 | return dagOpDef->getName() == "either" ; |
195 | } |
196 | |
197 | bool DagNode::isVariadic() const { |
198 | auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef(); |
199 | return dagOpDef->getName() == "variadic" ; |
200 | } |
201 | |
202 | void DagNode::print(raw_ostream &os) const { |
203 | if (node) |
204 | node->print(OS&: os); |
205 | } |
206 | |
207 | //===----------------------------------------------------------------------===// |
208 | // SymbolInfoMap |
209 | //===----------------------------------------------------------------------===// |
210 | |
211 | StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { |
212 | int idx = -1; |
213 | auto [name, indexStr] = symbol.rsplit(Separator: "__" ); |
214 | |
215 | if (indexStr.consumeInteger(Radix: 10, Result&: idx)) { |
216 | // The second part is not an index; we return the whole symbol as-is. |
217 | return symbol; |
218 | } |
219 | if (index) { |
220 | *index = idx; |
221 | } |
222 | return name; |
223 | } |
224 | |
225 | SymbolInfoMap::SymbolInfo::SymbolInfo( |
226 | const Operator *op, SymbolInfo::Kind kind, |
227 | std::optional<DagAndConstant> dagAndConstant) |
228 | : op(op), kind(kind), dagAndConstant(dagAndConstant) {} |
229 | |
230 | int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { |
231 | switch (kind) { |
232 | case Kind::Attr: |
233 | case Kind::Operand: |
234 | case Kind::Value: |
235 | return 1; |
236 | case Kind::Result: |
237 | return op->getNumResults(); |
238 | case Kind::MultipleValues: |
239 | return getSize(); |
240 | } |
241 | llvm_unreachable("unknown kind" ); |
242 | } |
243 | |
244 | std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { |
245 | return alternativeName ? *alternativeName : name.str(); |
246 | } |
247 | |
248 | std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { |
249 | LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': " ); |
250 | switch (kind) { |
251 | case Kind::Attr: { |
252 | if (op) |
253 | return op->getArg(index: getArgIndex()) |
254 | .get<NamedAttribute *>() |
255 | ->attr.getStorageType() |
256 | .str(); |
257 | // TODO(suderman): Use a more exact type when available. |
258 | return "::mlir::Attribute" ; |
259 | } |
260 | case Kind::Operand: { |
261 | // Use operand range for captured operands (to support potential variadic |
262 | // operands). |
263 | return "::mlir::Operation::operand_range" ; |
264 | } |
265 | case Kind::Value: { |
266 | return "::mlir::Value" ; |
267 | } |
268 | case Kind::MultipleValues: { |
269 | return "::mlir::ValueRange" ; |
270 | } |
271 | case Kind::Result: { |
272 | // Use the op itself for captured results. |
273 | return op->getQualCppClassName(); |
274 | } |
275 | } |
276 | llvm_unreachable("unknown kind" ); |
277 | } |
278 | |
279 | std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { |
280 | LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': " ); |
281 | std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "" ; |
282 | return std::string( |
283 | formatv(Fmt: "{0} {1}{2};\n" , Vals: getVarTypeStr(name), Vals: getVarName(name), Vals&: varInit)); |
284 | } |
285 | |
286 | std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { |
287 | LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': " ); |
288 | return std::string( |
289 | formatv(Fmt: "{0} &{1}" , Vals: getVarTypeStr(name), Vals: getVarName(name))); |
290 | } |
291 | |
292 | std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( |
293 | StringRef name, int index, const char *fmt, const char *separator) const { |
294 | LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': " ); |
295 | switch (kind) { |
296 | case Kind::Attr: { |
297 | assert(index < 0); |
298 | auto repl = formatv(Fmt: fmt, Vals&: name); |
299 | LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n" ); |
300 | return std::string(repl); |
301 | } |
302 | case Kind::Operand: { |
303 | assert(index < 0); |
304 | auto *operand = op->getArg(index: getArgIndex()).get<NamedTypeConstraint *>(); |
305 | // If this operand is variadic and this SymbolInfo doesn't have a range |
306 | // index, then return the full variadic operand_range. Otherwise, return |
307 | // the value itself. |
308 | if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) { |
309 | auto repl = formatv(Fmt: fmt, Vals&: name); |
310 | LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n" ); |
311 | return std::string(repl); |
312 | } |
313 | auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "(*{0}.begin())" , Vals&: name)); |
314 | LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n" ); |
315 | return std::string(repl); |
316 | } |
317 | case Kind::Result: { |
318 | // If `index` is greater than zero, then we are referencing a specific |
319 | // result of a multi-result op. The result can still be variadic. |
320 | if (index >= 0) { |
321 | std::string v = |
322 | std::string(formatv(Fmt: "{0}.getODSResults({1})" , Vals&: name, Vals&: index)); |
323 | if (!op->getResult(index).isVariadic()) |
324 | v = std::string(formatv(Fmt: "(*{0}.begin())" , Vals&: v)); |
325 | auto repl = formatv(Fmt: fmt, Vals&: v); |
326 | LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n" ); |
327 | return std::string(repl); |
328 | } |
329 | |
330 | // If this op has no result at all but still we bind a symbol to it, it |
331 | // means we want to capture the op itself. |
332 | if (op->getNumResults() == 0) { |
333 | LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n" ); |
334 | return formatv(Fmt: fmt, Vals&: name); |
335 | } |
336 | |
337 | // We are referencing all results of the multi-result op. A specific result |
338 | // can either be a value or a range. Then join them with `separator`. |
339 | SmallVector<std::string, 4> values; |
340 | values.reserve(N: op->getNumResults()); |
341 | |
342 | for (int i = 0, e = op->getNumResults(); i < e; ++i) { |
343 | std::string v = std::string(formatv(Fmt: "{0}.getODSResults({1})" , Vals&: name, Vals&: i)); |
344 | if (!op->getResult(index: i).isVariadic()) { |
345 | v = std::string(formatv(Fmt: "(*{0}.begin())" , Vals&: v)); |
346 | } |
347 | values.push_back(Elt: std::string(formatv(Fmt: fmt, Vals&: v))); |
348 | } |
349 | auto repl = llvm::join(R&: values, Separator: separator); |
350 | LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n" ); |
351 | return repl; |
352 | } |
353 | case Kind::Value: { |
354 | assert(index < 0); |
355 | assert(op == nullptr); |
356 | auto repl = formatv(Fmt: fmt, Vals&: name); |
357 | LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n" ); |
358 | return std::string(repl); |
359 | } |
360 | case Kind::MultipleValues: { |
361 | assert(op == nullptr); |
362 | assert(index < getSize()); |
363 | if (index >= 0) { |
364 | std::string repl = |
365 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]" , Vals&: name, Vals&: index))); |
366 | LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n" ); |
367 | return repl; |
368 | } |
369 | // If it doesn't specify certain element, unpack them all. |
370 | auto repl = |
371 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()" , Vals&: name))); |
372 | LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n" ); |
373 | return std::string(repl); |
374 | } |
375 | } |
376 | llvm_unreachable("unknown kind" ); |
377 | } |
378 | |
379 | std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( |
380 | StringRef name, int index, const char *fmt, const char *separator) const { |
381 | LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': " ); |
382 | switch (kind) { |
383 | case Kind::Attr: |
384 | case Kind::Operand: { |
385 | assert(index < 0 && "only allowed for symbol bound to result" ); |
386 | auto repl = formatv(Fmt: fmt, Vals&: name); |
387 | LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n" ); |
388 | return std::string(repl); |
389 | } |
390 | case Kind::Result: { |
391 | if (index >= 0) { |
392 | auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})" , Vals&: name, Vals&: index)); |
393 | LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n" ); |
394 | return std::string(repl); |
395 | } |
396 | |
397 | // We are referencing all results of the multi-result op. Each result should |
398 | // have a value range, and then join them with `separator`. |
399 | SmallVector<std::string, 4> values; |
400 | values.reserve(N: op->getNumResults()); |
401 | |
402 | for (int i = 0, e = op->getNumResults(); i < e; ++i) { |
403 | values.push_back(Elt: std::string( |
404 | formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})" , Vals&: name, Vals&: i)))); |
405 | } |
406 | auto repl = llvm::join(R&: values, Separator: separator); |
407 | LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n" ); |
408 | return repl; |
409 | } |
410 | case Kind::Value: { |
411 | assert(index < 0 && "only allowed for symbol bound to result" ); |
412 | assert(op == nullptr); |
413 | auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{{{0}}" , Vals&: name)); |
414 | LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n" ); |
415 | return std::string(repl); |
416 | } |
417 | case Kind::MultipleValues: { |
418 | assert(op == nullptr); |
419 | assert(index < getSize()); |
420 | if (index >= 0) { |
421 | std::string repl = |
422 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]" , Vals&: name, Vals&: index))); |
423 | LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n" ); |
424 | return repl; |
425 | } |
426 | auto repl = |
427 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()" , Vals&: name))); |
428 | LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n" ); |
429 | return std::string(repl); |
430 | } |
431 | } |
432 | llvm_unreachable("unknown kind" ); |
433 | } |
434 | |
435 | bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, |
436 | const Operator &op, int argIndex, |
437 | std::optional<int> variadicSubIndex) { |
438 | StringRef name = getValuePackName(symbol); |
439 | if (name != symbol) { |
440 | auto error = formatv( |
441 | Fmt: "symbol '{0}' with trailing index cannot bind to op argument" , Vals&: symbol); |
442 | PrintFatalError(ErrorLoc: loc, Msg: error); |
443 | } |
444 | |
445 | auto symInfo = |
446 | op.getArg(index: argIndex).is<NamedAttribute *>() |
447 | ? SymbolInfo::getAttr(op: &op, index: argIndex) |
448 | : SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex); |
449 | |
450 | std::string key = symbol.str(); |
451 | if (symbolInfoMap.count(x: key)) { |
452 | // Only non unique name for the operand is supported. |
453 | if (symInfo.kind != SymbolInfo::Kind::Operand) { |
454 | return false; |
455 | } |
456 | |
457 | // Cannot add new operand if there is already non operand with the same |
458 | // name. |
459 | if (symbolInfoMap.find(x: key)->second.kind != SymbolInfo::Kind::Operand) { |
460 | return false; |
461 | } |
462 | } |
463 | |
464 | symbolInfoMap.emplace(args&: key, args&: symInfo); |
465 | return true; |
466 | } |
467 | |
468 | bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { |
469 | std::string name = getValuePackName(symbol).str(); |
470 | auto inserted = symbolInfoMap.emplace(args&: name, args: SymbolInfo::getResult(op: &op)); |
471 | |
472 | return symbolInfoMap.count(x: inserted->first) == 1; |
473 | } |
474 | |
475 | bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { |
476 | std::string name = getValuePackName(symbol).str(); |
477 | if (numValues > 1) |
478 | return bindMultipleValues(symbol: name, numValues); |
479 | return bindValue(symbol: name); |
480 | } |
481 | |
482 | bool SymbolInfoMap::bindValue(StringRef symbol) { |
483 | auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getValue()); |
484 | return symbolInfoMap.count(x: inserted->first) == 1; |
485 | } |
486 | |
487 | bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { |
488 | std::string name = getValuePackName(symbol).str(); |
489 | auto inserted = |
490 | symbolInfoMap.emplace(args&: name, args: SymbolInfo::getMultipleValues(numValues)); |
491 | return symbolInfoMap.count(x: inserted->first) == 1; |
492 | } |
493 | |
494 | bool SymbolInfoMap::bindAttr(StringRef symbol) { |
495 | auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getAttr()); |
496 | return symbolInfoMap.count(x: inserted->first) == 1; |
497 | } |
498 | |
499 | bool SymbolInfoMap::contains(StringRef symbol) const { |
500 | return find(key: symbol) != symbolInfoMap.end(); |
501 | } |
502 | |
503 | SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { |
504 | std::string name = getValuePackName(symbol: key).str(); |
505 | |
506 | return symbolInfoMap.find(x: name); |
507 | } |
508 | |
509 | SymbolInfoMap::const_iterator |
510 | SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, |
511 | int argIndex, |
512 | std::optional<int> variadicSubIndex) const { |
513 | return findBoundSymbol( |
514 | key, symbolInfo: SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex)); |
515 | } |
516 | |
517 | SymbolInfoMap::const_iterator |
518 | SymbolInfoMap::findBoundSymbol(StringRef key, |
519 | const SymbolInfo &symbolInfo) const { |
520 | std::string name = getValuePackName(symbol: key).str(); |
521 | auto range = symbolInfoMap.equal_range(x: name); |
522 | |
523 | for (auto it = range.first; it != range.second; ++it) |
524 | if (it->second.dagAndConstant == symbolInfo.dagAndConstant) |
525 | return it; |
526 | |
527 | return symbolInfoMap.end(); |
528 | } |
529 | |
530 | std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> |
531 | SymbolInfoMap::getRangeOfEqualElements(StringRef key) { |
532 | std::string name = getValuePackName(symbol: key).str(); |
533 | |
534 | return symbolInfoMap.equal_range(x: name); |
535 | } |
536 | |
537 | int SymbolInfoMap::count(StringRef key) const { |
538 | std::string name = getValuePackName(symbol: key).str(); |
539 | return symbolInfoMap.count(x: name); |
540 | } |
541 | |
542 | int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { |
543 | StringRef name = getValuePackName(symbol); |
544 | if (name != symbol) { |
545 | // If there is a trailing index inside symbol, it references just one |
546 | // static value. |
547 | return 1; |
548 | } |
549 | // Otherwise, find how many it represents by querying the symbol's info. |
550 | return find(key: name)->second.getStaticValueCount(); |
551 | } |
552 | |
553 | std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, |
554 | const char *fmt, |
555 | const char *separator) const { |
556 | int index = -1; |
557 | StringRef name = getValuePackName(symbol, index: &index); |
558 | |
559 | auto it = symbolInfoMap.find(x: name.str()); |
560 | if (it == symbolInfoMap.end()) { |
561 | auto error = formatv(Fmt: "referencing unbound symbol '{0}'" , Vals&: symbol); |
562 | PrintFatalError(ErrorLoc: loc, Msg: error); |
563 | } |
564 | |
565 | return it->second.getValueAndRangeUse(name, index, fmt, separator); |
566 | } |
567 | |
568 | std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, |
569 | const char *separator) const { |
570 | int index = -1; |
571 | StringRef name = getValuePackName(symbol, index: &index); |
572 | |
573 | auto it = symbolInfoMap.find(x: name.str()); |
574 | if (it == symbolInfoMap.end()) { |
575 | auto error = formatv(Fmt: "referencing unbound symbol '{0}'" , Vals&: symbol); |
576 | PrintFatalError(ErrorLoc: loc, Msg: error); |
577 | } |
578 | |
579 | return it->second.getAllRangeUse(name, index, fmt, separator); |
580 | } |
581 | |
582 | void SymbolInfoMap::assignUniqueAlternativeNames() { |
583 | llvm::StringSet<> usedNames; |
584 | |
585 | for (auto symbolInfoIt = symbolInfoMap.begin(); |
586 | symbolInfoIt != symbolInfoMap.end();) { |
587 | auto range = symbolInfoMap.equal_range(x: symbolInfoIt->first); |
588 | auto startRange = range.first; |
589 | auto endRange = range.second; |
590 | |
591 | auto operandName = symbolInfoIt->first; |
592 | int startSearchIndex = 0; |
593 | for (++startRange; startRange != endRange; ++startRange) { |
594 | // Current operand name is not unique, find a unique one |
595 | // and set the alternative name. |
596 | for (int i = startSearchIndex;; ++i) { |
597 | std::string alternativeName = operandName + std::to_string(val: i); |
598 | if (!usedNames.contains(key: alternativeName) && |
599 | symbolInfoMap.count(x: alternativeName) == 0) { |
600 | usedNames.insert(key: alternativeName); |
601 | startRange->second.alternativeName = alternativeName; |
602 | startSearchIndex = i + 1; |
603 | |
604 | break; |
605 | } |
606 | } |
607 | } |
608 | |
609 | symbolInfoIt = endRange; |
610 | } |
611 | } |
612 | |
613 | //===----------------------------------------------------------------------===// |
614 | // Pattern |
//===----------------------------------------------------------------------===//
616 | |
617 | Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) |
618 | : def(*def), recordOpMap(mapper) {} |
619 | |
620 | DagNode Pattern::getSourcePattern() const { |
621 | return DagNode(def.getValueAsDag(FieldName: "sourcePattern" )); |
622 | } |
623 | |
624 | int Pattern::getNumResultPatterns() const { |
625 | auto *results = def.getValueAsListInit(FieldName: "resultPatterns" ); |
626 | return results->size(); |
627 | } |
628 | |
629 | DagNode Pattern::getResultPattern(unsigned index) const { |
630 | auto *results = def.getValueAsListInit(FieldName: "resultPatterns" ); |
631 | return DagNode(cast<llvm::DagInit>(Val: results->getElement(i: index))); |
632 | } |
633 | |
634 | void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { |
635 | LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n" ); |
636 | collectBoundSymbols(tree: getSourcePattern(), infoMap, /*isSrcPattern=*/true); |
637 | LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n" ); |
638 | |
639 | LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n" ); |
640 | infoMap.assignUniqueAlternativeNames(); |
641 | LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n" ); |
642 | } |
643 | |
644 | void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { |
645 | LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n" ); |
646 | for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { |
647 | auto pattern = getResultPattern(index: i); |
648 | collectBoundSymbols(tree: pattern, infoMap, /*isSrcPattern=*/false); |
649 | } |
650 | LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n" ); |
651 | } |
652 | |
653 | const Operator &Pattern::getSourceRootOp() { |
654 | return getSourcePattern().getDialectOp(mapper: recordOpMap); |
655 | } |
656 | |
657 | Operator &Pattern::getDialectOp(DagNode node) { |
658 | return node.getDialectOp(mapper: recordOpMap); |
659 | } |
660 | |
661 | std::vector<AppliedConstraint> Pattern::getConstraints() const { |
662 | auto *listInit = def.getValueAsListInit(FieldName: "constraints" ); |
663 | std::vector<AppliedConstraint> ret; |
664 | ret.reserve(n: listInit->size()); |
665 | |
666 | for (auto *it : *listInit) { |
667 | auto *dagInit = dyn_cast<llvm::DagInit>(Val: it); |
668 | if (!dagInit) |
669 | PrintFatalError(Rec: &def, Msg: "all elements in Pattern multi-entity " |
670 | "constraints should be DAG nodes" ); |
671 | |
672 | std::vector<std::string> entities; |
673 | entities.reserve(n: dagInit->arg_size()); |
674 | for (auto *argName : dagInit->getArgNames()) { |
675 | if (!argName) { |
676 | PrintFatalError( |
677 | Rec: &def, |
678 | Msg: "operands to additional constraints can only be symbol references" ); |
679 | } |
680 | entities.emplace_back(args: argName->getValue()); |
681 | } |
682 | |
683 | ret.emplace_back(args: cast<llvm::DefInit>(Val: dagInit->getOperator())->getDef(), |
684 | args: dagInit->getNameStr(), args: std::move(entities)); |
685 | } |
686 | return ret; |
687 | } |
688 | |
689 | int Pattern::getNumSupplementalPatterns() const { |
690 | auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns" ); |
691 | return results->size(); |
692 | } |
693 | |
694 | DagNode Pattern::getSupplementalPattern(unsigned index) const { |
695 | auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns" ); |
696 | return DagNode(cast<llvm::DagInit>(Val: results->getElement(i: index))); |
697 | } |
698 | |
699 | int Pattern::getBenefit() const { |
700 | // The initial benefit value is a heuristic with number of ops in the source |
701 | // pattern. |
702 | int initBenefit = getSourcePattern().getNumOps(); |
703 | llvm::DagInit *delta = def.getValueAsDag(FieldName: "benefitDelta" ); |
704 | if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(Val: delta->getArg(Num: 0))) { |
705 | PrintFatalError(Rec: &def, |
706 | Msg: "The 'addBenefit' takes and only takes one integer value" ); |
707 | } |
708 | return initBenefit + dyn_cast<llvm::IntInit>(Val: delta->getArg(Num: 0))->getValue(); |
709 | } |
710 | |
711 | std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { |
712 | std::vector<std::pair<StringRef, unsigned>> result; |
713 | result.reserve(n: def.getLoc().size()); |
714 | for (auto loc : def.getLoc()) { |
715 | unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(Loc: loc); |
716 | assert(buf && "invalid source location" ); |
717 | result.emplace_back( |
718 | args: llvm::SrcMgr.getBufferInfo(i: buf).Buffer->getBufferIdentifier(), |
719 | args: llvm::SrcMgr.getLineAndColumn(Loc: loc, BufferID: buf).first); |
720 | } |
721 | return result; |
722 | } |
723 | |
724 | void Pattern::verifyBind(bool result, StringRef symbolName) { |
725 | if (!result) { |
726 | auto err = formatv(Fmt: "symbol '{0}' bound more than once" , Vals&: symbolName); |
727 | PrintFatalError(Rec: &def, Msg: err); |
728 | } |
729 | } |
730 | |
731 | void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, |
732 | bool isSrcPattern) { |
733 | auto treeName = tree.getSymbol(); |
734 | auto numTreeArgs = tree.getNumArgs(); |
735 | |
736 | if (tree.isNativeCodeCall()) { |
737 | if (!treeName.empty()) { |
738 | if (!isSrcPattern) { |
739 | LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " |
740 | << treeName << '\n'); |
741 | verifyBind( |
742 | result: infoMap.bindValues(symbol: treeName, numValues: tree.getNumReturnsOfNativeCode()), |
743 | symbolName: treeName); |
744 | } else { |
745 | PrintFatalError(Rec: &def, |
746 | Msg: formatv(Fmt: "binding symbol '{0}' to NativecodeCall in " |
747 | "MatchPattern is not supported" , |
748 | Vals&: treeName)); |
749 | } |
750 | } |
751 | |
752 | for (int i = 0; i != numTreeArgs; ++i) { |
753 | if (auto treeArg = tree.getArgAsNestedDag(index: i)) { |
754 | // This DAG node argument is a DAG node itself. Go inside recursively. |
755 | collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern); |
756 | continue; |
757 | } |
758 | |
759 | if (!isSrcPattern) |
760 | continue; |
761 | |
762 | // We can only bind symbols to arguments in source pattern. Those |
763 | // symbols are referenced in result patterns. |
764 | auto treeArgName = tree.getArgName(index: i); |
765 | |
766 | // `$_` is a special symbol meaning ignore the current argument. |
767 | if (!treeArgName.empty() && treeArgName != "_" ) { |
768 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
769 | |
770 | // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), |
771 | if (leaf.isUnspecified()) { |
772 | // This is case of $c, a Value without any constraints. |
773 | verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName); |
774 | } else { |
775 | auto constraint = leaf.getAsConstraint(); |
776 | bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || |
777 | leaf.isConstantAttr() || |
778 | constraint.getKind() == Constraint::Kind::CK_Attr; |
779 | |
780 | if (isAttr) { |
781 | // This is case of $a, a binding to a certain attribute. |
782 | verifyBind(result: infoMap.bindAttr(symbol: treeArgName), symbolName: treeArgName); |
783 | continue; |
784 | } |
785 | |
786 | // This is case of $b, a binding to a certain type. |
787 | verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName); |
788 | } |
789 | } |
790 | } |
791 | |
792 | return; |
793 | } |
794 | |
795 | if (tree.isOperation()) { |
796 | auto &op = getDialectOp(node: tree); |
797 | auto numOpArgs = op.getNumArgs(); |
798 | int numEither = 0; |
799 | |
800 | // We need to exclude the trailing directives and `either` directive groups |
801 | // two operands of the operation. |
802 | int numDirectives = 0; |
803 | for (int i = numTreeArgs - 1; i >= 0; --i) { |
804 | if (auto dagArg = tree.getArgAsNestedDag(index: i)) { |
805 | if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) |
806 | ++numDirectives; |
807 | else if (dagArg.isEither()) |
808 | ++numEither; |
809 | } |
810 | } |
811 | |
812 | if (numOpArgs != numTreeArgs - numDirectives + numEither) { |
813 | auto err = |
814 | formatv(Fmt: "op '{0}' argument number mismatch: " |
815 | "{1} in pattern vs. {2} in definition" , |
816 | Vals: op.getOperationName(), Vals: numTreeArgs + numEither, Vals&: numOpArgs); |
817 | PrintFatalError(Rec: &def, Msg: err); |
818 | } |
819 | |
820 | // The name attached to the DAG node's operator is for representing the |
821 | // results generated from this op. It should be remembered as bound results. |
822 | if (!treeName.empty()) { |
823 | LLVM_DEBUG(llvm::dbgs() |
824 | << "found symbol bound to op result: " << treeName << '\n'); |
825 | verifyBind(result: infoMap.bindOpResult(symbol: treeName, op), symbolName: treeName); |
826 | } |
827 | |
828 | // The operand in `either` DAG should be bound to the operation in the |
829 | // parent DagNode. |
830 | auto collectSymbolInEither = [&](DagNode parent, DagNode tree, |
831 | int opArgIdx) { |
832 | for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { |
833 | if (DagNode subTree = tree.getArgAsNestedDag(index: i)) { |
834 | collectBoundSymbols(tree: subTree, infoMap, isSrcPattern); |
835 | } else { |
836 | auto argName = tree.getArgName(index: i); |
837 | if (!argName.empty() && argName != "_" ) { |
838 | verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx), |
839 | symbolName: argName); |
840 | } |
841 | } |
842 | } |
843 | }; |
844 | |
845 | // The operand in `variadic` DAG should be bound to the operation in the |
846 | // parent DagNode. The range index must be included as well to distinguish |
847 | // (potentially) repeating argName within the `variadic` DAG. |
848 | auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree, |
849 | int opArgIdx) { |
850 | auto treeName = tree.getSymbol(); |
851 | if (!treeName.empty()) { |
852 | // If treeName is specified, bind to the full variadic operand_range. |
853 | verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: treeName, op, argIndex: opArgIdx, |
854 | variadicSubIndex: std::nullopt), |
855 | symbolName: treeName); |
856 | } |
857 | |
858 | for (int i = 0; i < tree.getNumArgs(); ++i) { |
859 | if (DagNode subTree = tree.getArgAsNestedDag(index: i)) { |
860 | collectBoundSymbols(tree: subTree, infoMap, isSrcPattern); |
861 | } else { |
862 | auto argName = tree.getArgName(index: i); |
863 | if (!argName.empty() && argName != "_" ) { |
864 | verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx, |
865 | /*variadicSubIndex=*/i), |
866 | symbolName: argName); |
867 | } |
868 | } |
869 | } |
870 | }; |
871 | |
872 | for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { |
873 | if (auto treeArg = tree.getArgAsNestedDag(index: i)) { |
874 | if (treeArg.isEither()) { |
875 | collectSymbolInEither(tree, treeArg, opArgIdx); |
876 | // `either` DAG is *flattened*. For example, |
877 | // |
878 | // (FooOp (either arg0, arg1), arg2) |
879 | // |
880 | // can be viewed as: |
881 | // |
882 | // (FooOp arg0, arg1, arg2) |
883 | ++opArgIdx; |
884 | } else if (treeArg.isVariadic()) { |
885 | collectSymbolInVariadic(tree, treeArg, opArgIdx); |
886 | } else { |
887 | // This DAG node argument is a DAG node itself. Go inside recursively. |
888 | collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern); |
889 | } |
890 | continue; |
891 | } |
892 | |
893 | if (isSrcPattern) { |
894 | // We can only bind symbols to op arguments in source pattern. Those |
895 | // symbols are referenced in result patterns. |
896 | auto treeArgName = tree.getArgName(index: i); |
897 | // `$_` is a special symbol meaning ignore the current argument. |
898 | if (!treeArgName.empty() && treeArgName != "_" ) { |
899 | LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " |
900 | << treeArgName << '\n'); |
901 | verifyBind(result: infoMap.bindOpArgument(node: tree, symbol: treeArgName, op, argIndex: opArgIdx), |
902 | symbolName: treeArgName); |
903 | } |
904 | } |
905 | } |
906 | return; |
907 | } |
908 | |
909 | if (!treeName.empty()) { |
910 | PrintFatalError( |
911 | Rec: &def, Msg: formatv(Fmt: "binding symbol '{0}' to non-operation/native code call " |
912 | "unsupported right now" , |
913 | Vals&: treeName)); |
914 | } |
915 | } |
916 | |