1//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
16#include "mlir/TableGen/Pattern.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23
24#define DEBUG_TYPE "mlir-tblgen-pattern"
25
26using namespace mlir;
27using namespace tblgen;
28
29using llvm::formatv;
30
31//===----------------------------------------------------------------------===//
32// DagLeaf
33//===----------------------------------------------------------------------===//
34
35bool DagLeaf::isUnspecified() const {
36 return isa_and_nonnull<llvm::UnsetInit>(Val: def);
37}
38
39bool DagLeaf::isOperandMatcher() const {
40 // Operand matchers specify a type constraint.
41 return isSubClassOf(superclass: "TypeConstraint");
42}
43
44bool DagLeaf::isAttrMatcher() const {
45 // Attribute matchers specify an attribute constraint.
46 return isSubClassOf(superclass: "AttrConstraint");
47}
48
49bool DagLeaf::isNativeCodeCall() const {
50 return isSubClassOf(superclass: "NativeCodeCall");
51}
52
53bool DagLeaf::isConstantAttr() const { return isSubClassOf(superclass: "ConstantAttr"); }
54
55bool DagLeaf::isEnumAttrCase() const {
56 return isSubClassOf(superclass: "EnumAttrCaseInfo");
57}
58
59bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(Val: def); }
60
61Constraint DagLeaf::getAsConstraint() const {
62 assert((isOperandMatcher() || isAttrMatcher()) &&
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast<llvm::DefInit>(Val: def)->getDef());
65}
66
67ConstantAttr DagLeaf::getAsConstantAttr() const {
68 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69 return ConstantAttr(cast<llvm::DefInit>(Val: def));
70}
71
72EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74 return EnumAttrCase(cast<llvm::DefInit>(Val: def));
75}
76
77std::string DagLeaf::getConditionTemplate() const {
78 return getAsConstraint().getConditionTemplate();
79}
80
81llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83 return cast<llvm::DefInit>(Val: def)->getDef()->getValueAsString(FieldName: "expression");
84}
85
86int DagLeaf::getNumReturnsOfNativeCode() const {
87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88 return cast<llvm::DefInit>(Val: def)->getDef()->getValueAsInt(FieldName: "numReturns");
89}
90
91std::string DagLeaf::getStringAttr() const {
92 assert(isStringAttr() && "the DAG leaf must be string attribute");
93 return def->getAsUnquotedString();
94}
95bool DagLeaf::isSubClassOf(StringRef superclass) const {
96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: def))
97 return defInit->getDef()->isSubClassOf(Name: superclass);
98 return false;
99}
100
101void DagLeaf::print(raw_ostream &os) const {
102 if (def)
103 def->print(OS&: os);
104}
105
106//===----------------------------------------------------------------------===//
107// DagNode
108//===----------------------------------------------------------------------===//
109
110bool DagNode::isNativeCodeCall() const {
111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: node->getOperator()))
112 return defInit->getDef()->isSubClassOf(Name: "NativeCodeCall");
113 return false;
114}
115
116bool DagNode::isOperation() const {
117 return !isNativeCodeCall() && !isReplaceWithValue() &&
118 !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
119 !isVariadic();
120}
121
122llvm::StringRef DagNode::getNativeCodeTemplate() const {
123 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
124 return cast<llvm::DefInit>(Val: node->getOperator())
125 ->getDef()
126 ->getValueAsString(FieldName: "expression");
127}
128
129int DagNode::getNumReturnsOfNativeCode() const {
130 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
131 return cast<llvm::DefInit>(Val: node->getOperator())
132 ->getDef()
133 ->getValueAsInt(FieldName: "numReturns");
134}
135
136llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
137
138Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
139 llvm::Record *opDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef();
140 auto it = mapper->find(Val: opDef);
141 if (it != mapper->end())
142 return *it->second;
143 return *mapper->try_emplace(Key: opDef, Args: std::make_unique<Operator>(args&: opDef))
144 .first->second;
145}
146
147int DagNode::getNumOps() const {
148 // We want to get number of operations recursively involved in the DAG tree.
149 // All other directives should be excluded.
150 int count = isOperation() ? 1 : 0;
151 for (int i = 0, e = getNumArgs(); i != e; ++i) {
152 if (auto child = getArgAsNestedDag(index: i))
153 count += child.getNumOps();
154 }
155 return count;
156}
157
158int DagNode::getNumArgs() const { return node->getNumArgs(); }
159
160bool DagNode::isNestedDagArg(unsigned index) const {
161 return isa<llvm::DagInit>(Val: node->getArg(Num: index));
162}
163
164DagNode DagNode::getArgAsNestedDag(unsigned index) const {
165 return DagNode(dyn_cast_or_null<llvm::DagInit>(Val: node->getArg(Num: index)));
166}
167
168DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
169 assert(!isNestedDagArg(index));
170 return DagLeaf(node->getArg(Num: index));
171}
172
173StringRef DagNode::getArgName(unsigned index) const {
174 return node->getArgNameStr(Num: index);
175}
176
177bool DagNode::isReplaceWithValue() const {
178 auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef();
179 return dagOpDef->getName() == "replaceWithValue";
180}
181
182bool DagNode::isLocationDirective() const {
183 auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef();
184 return dagOpDef->getName() == "location";
185}
186
187bool DagNode::isReturnTypeDirective() const {
188 auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef();
189 return dagOpDef->getName() == "returnType";
190}
191
192bool DagNode::isEither() const {
193 auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef();
194 return dagOpDef->getName() == "either";
195}
196
197bool DagNode::isVariadic() const {
198 auto *dagOpDef = cast<llvm::DefInit>(Val: node->getOperator())->getDef();
199 return dagOpDef->getName() == "variadic";
200}
201
202void DagNode::print(raw_ostream &os) const {
203 if (node)
204 node->print(OS&: os);
205}
206
207//===----------------------------------------------------------------------===//
208// SymbolInfoMap
209//===----------------------------------------------------------------------===//
210
211StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
212 int idx = -1;
213 auto [name, indexStr] = symbol.rsplit(Separator: "__");
214
215 if (indexStr.consumeInteger(Radix: 10, Result&: idx)) {
216 // The second part is not an index; we return the whole symbol as-is.
217 return symbol;
218 }
219 if (index) {
220 *index = idx;
221 }
222 return name;
223}
224
225SymbolInfoMap::SymbolInfo::SymbolInfo(
226 const Operator *op, SymbolInfo::Kind kind,
227 std::optional<DagAndConstant> dagAndConstant)
228 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
229
230int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
231 switch (kind) {
232 case Kind::Attr:
233 case Kind::Operand:
234 case Kind::Value:
235 return 1;
236 case Kind::Result:
237 return op->getNumResults();
238 case Kind::MultipleValues:
239 return getSize();
240 }
241 llvm_unreachable("unknown kind");
242}
243
244std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
245 return alternativeName ? *alternativeName : name.str();
246}
247
248std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
249 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
250 switch (kind) {
251 case Kind::Attr: {
252 if (op)
253 return op->getArg(index: getArgIndex())
254 .get<NamedAttribute *>()
255 ->attr.getStorageType()
256 .str();
257 // TODO(suderman): Use a more exact type when available.
258 return "::mlir::Attribute";
259 }
260 case Kind::Operand: {
261 // Use operand range for captured operands (to support potential variadic
262 // operands).
263 return "::mlir::Operation::operand_range";
264 }
265 case Kind::Value: {
266 return "::mlir::Value";
267 }
268 case Kind::MultipleValues: {
269 return "::mlir::ValueRange";
270 }
271 case Kind::Result: {
272 // Use the op itself for captured results.
273 return op->getQualCppClassName();
274 }
275 }
276 llvm_unreachable("unknown kind");
277}
278
279std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
280 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
281 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
282 return std::string(
283 formatv(Fmt: "{0} {1}{2};\n", Vals: getVarTypeStr(name), Vals: getVarName(name), Vals&: varInit));
284}
285
286std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
287 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
288 return std::string(
289 formatv(Fmt: "{0} &{1}", Vals: getVarTypeStr(name), Vals: getVarName(name)));
290}
291
292std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
293 StringRef name, int index, const char *fmt, const char *separator) const {
294 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
295 switch (kind) {
296 case Kind::Attr: {
297 assert(index < 0);
298 auto repl = formatv(Fmt: fmt, Vals&: name);
299 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
300 return std::string(repl);
301 }
302 case Kind::Operand: {
303 assert(index < 0);
304 auto *operand = op->getArg(index: getArgIndex()).get<NamedTypeConstraint *>();
305 // If this operand is variadic and this SymbolInfo doesn't have a range
306 // index, then return the full variadic operand_range. Otherwise, return
307 // the value itself.
308 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
309 auto repl = formatv(Fmt: fmt, Vals&: name);
310 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
311 return std::string(repl);
312 }
313 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "(*{0}.begin())", Vals&: name));
314 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
315 return std::string(repl);
316 }
317 case Kind::Result: {
318 // If `index` is greater than zero, then we are referencing a specific
319 // result of a multi-result op. The result can still be variadic.
320 if (index >= 0) {
321 std::string v =
322 std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index));
323 if (!op->getResult(index).isVariadic())
324 v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v));
325 auto repl = formatv(Fmt: fmt, Vals&: v);
326 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
327 return std::string(repl);
328 }
329
330 // If this op has no result at all but still we bind a symbol to it, it
331 // means we want to capture the op itself.
332 if (op->getNumResults() == 0) {
333 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
334 return formatv(Fmt: fmt, Vals&: name);
335 }
336
337 // We are referencing all results of the multi-result op. A specific result
338 // can either be a value or a range. Then join them with `separator`.
339 SmallVector<std::string, 4> values;
340 values.reserve(N: op->getNumResults());
341
342 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
343 std::string v = std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i));
344 if (!op->getResult(index: i).isVariadic()) {
345 v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v));
346 }
347 values.push_back(Elt: std::string(formatv(Fmt: fmt, Vals&: v)));
348 }
349 auto repl = llvm::join(R&: values, Separator: separator);
350 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
351 return repl;
352 }
353 case Kind::Value: {
354 assert(index < 0);
355 assert(op == nullptr);
356 auto repl = formatv(Fmt: fmt, Vals&: name);
357 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
358 return std::string(repl);
359 }
360 case Kind::MultipleValues: {
361 assert(op == nullptr);
362 assert(index < getSize());
363 if (index >= 0) {
364 std::string repl =
365 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index)));
366 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
367 return repl;
368 }
369 // If it doesn't specify certain element, unpack them all.
370 auto repl =
371 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name)));
372 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
373 return std::string(repl);
374 }
375 }
376 llvm_unreachable("unknown kind");
377}
378
379std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
380 StringRef name, int index, const char *fmt, const char *separator) const {
381 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
382 switch (kind) {
383 case Kind::Attr:
384 case Kind::Operand: {
385 assert(index < 0 && "only allowed for symbol bound to result");
386 auto repl = formatv(Fmt: fmt, Vals&: name);
387 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
388 return std::string(repl);
389 }
390 case Kind::Result: {
391 if (index >= 0) {
392 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index));
393 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
394 return std::string(repl);
395 }
396
397 // We are referencing all results of the multi-result op. Each result should
398 // have a value range, and then join them with `separator`.
399 SmallVector<std::string, 4> values;
400 values.reserve(N: op->getNumResults());
401
402 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
403 values.push_back(Elt: std::string(
404 formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i))));
405 }
406 auto repl = llvm::join(R&: values, Separator: separator);
407 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
408 return repl;
409 }
410 case Kind::Value: {
411 assert(index < 0 && "only allowed for symbol bound to result");
412 assert(op == nullptr);
413 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{{{0}}", Vals&: name));
414 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
415 return std::string(repl);
416 }
417 case Kind::MultipleValues: {
418 assert(op == nullptr);
419 assert(index < getSize());
420 if (index >= 0) {
421 std::string repl =
422 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index)));
423 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
424 return repl;
425 }
426 auto repl =
427 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name)));
428 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
429 return std::string(repl);
430 }
431 }
432 llvm_unreachable("unknown kind");
433}
434
435bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
436 const Operator &op, int argIndex,
437 std::optional<int> variadicSubIndex) {
438 StringRef name = getValuePackName(symbol);
439 if (name != symbol) {
440 auto error = formatv(
441 Fmt: "symbol '{0}' with trailing index cannot bind to op argument", Vals&: symbol);
442 PrintFatalError(ErrorLoc: loc, Msg: error);
443 }
444
445 auto symInfo =
446 op.getArg(index: argIndex).is<NamedAttribute *>()
447 ? SymbolInfo::getAttr(op: &op, index: argIndex)
448 : SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex);
449
450 std::string key = symbol.str();
451 if (symbolInfoMap.count(x: key)) {
452 // Only non unique name for the operand is supported.
453 if (symInfo.kind != SymbolInfo::Kind::Operand) {
454 return false;
455 }
456
457 // Cannot add new operand if there is already non operand with the same
458 // name.
459 if (symbolInfoMap.find(x: key)->second.kind != SymbolInfo::Kind::Operand) {
460 return false;
461 }
462 }
463
464 symbolInfoMap.emplace(args&: key, args&: symInfo);
465 return true;
466}
467
468bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
469 std::string name = getValuePackName(symbol).str();
470 auto inserted = symbolInfoMap.emplace(args&: name, args: SymbolInfo::getResult(op: &op));
471
472 return symbolInfoMap.count(x: inserted->first) == 1;
473}
474
475bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
476 std::string name = getValuePackName(symbol).str();
477 if (numValues > 1)
478 return bindMultipleValues(symbol: name, numValues);
479 return bindValue(symbol: name);
480}
481
482bool SymbolInfoMap::bindValue(StringRef symbol) {
483 auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getValue());
484 return symbolInfoMap.count(x: inserted->first) == 1;
485}
486
487bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
488 std::string name = getValuePackName(symbol).str();
489 auto inserted =
490 symbolInfoMap.emplace(args&: name, args: SymbolInfo::getMultipleValues(numValues));
491 return symbolInfoMap.count(x: inserted->first) == 1;
492}
493
494bool SymbolInfoMap::bindAttr(StringRef symbol) {
495 auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getAttr());
496 return symbolInfoMap.count(x: inserted->first) == 1;
497}
498
499bool SymbolInfoMap::contains(StringRef symbol) const {
500 return find(key: symbol) != symbolInfoMap.end();
501}
502
503SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
504 std::string name = getValuePackName(symbol: key).str();
505
506 return symbolInfoMap.find(x: name);
507}
508
509SymbolInfoMap::const_iterator
510SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
511 int argIndex,
512 std::optional<int> variadicSubIndex) const {
513 return findBoundSymbol(
514 key, symbolInfo: SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex));
515}
516
517SymbolInfoMap::const_iterator
518SymbolInfoMap::findBoundSymbol(StringRef key,
519 const SymbolInfo &symbolInfo) const {
520 std::string name = getValuePackName(symbol: key).str();
521 auto range = symbolInfoMap.equal_range(x: name);
522
523 for (auto it = range.first; it != range.second; ++it)
524 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
525 return it;
526
527 return symbolInfoMap.end();
528}
529
530std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
531SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
532 std::string name = getValuePackName(symbol: key).str();
533
534 return symbolInfoMap.equal_range(x: name);
535}
536
537int SymbolInfoMap::count(StringRef key) const {
538 std::string name = getValuePackName(symbol: key).str();
539 return symbolInfoMap.count(x: name);
540}
541
542int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
543 StringRef name = getValuePackName(symbol);
544 if (name != symbol) {
545 // If there is a trailing index inside symbol, it references just one
546 // static value.
547 return 1;
548 }
549 // Otherwise, find how many it represents by querying the symbol's info.
550 return find(key: name)->second.getStaticValueCount();
551}
552
553std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
554 const char *fmt,
555 const char *separator) const {
556 int index = -1;
557 StringRef name = getValuePackName(symbol, index: &index);
558
559 auto it = symbolInfoMap.find(x: name.str());
560 if (it == symbolInfoMap.end()) {
561 auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol);
562 PrintFatalError(ErrorLoc: loc, Msg: error);
563 }
564
565 return it->second.getValueAndRangeUse(name, index, fmt, separator);
566}
567
568std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
569 const char *separator) const {
570 int index = -1;
571 StringRef name = getValuePackName(symbol, index: &index);
572
573 auto it = symbolInfoMap.find(x: name.str());
574 if (it == symbolInfoMap.end()) {
575 auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol);
576 PrintFatalError(ErrorLoc: loc, Msg: error);
577 }
578
579 return it->second.getAllRangeUse(name, index, fmt, separator);
580}
581
582void SymbolInfoMap::assignUniqueAlternativeNames() {
583 llvm::StringSet<> usedNames;
584
585 for (auto symbolInfoIt = symbolInfoMap.begin();
586 symbolInfoIt != symbolInfoMap.end();) {
587 auto range = symbolInfoMap.equal_range(x: symbolInfoIt->first);
588 auto startRange = range.first;
589 auto endRange = range.second;
590
591 auto operandName = symbolInfoIt->first;
592 int startSearchIndex = 0;
593 for (++startRange; startRange != endRange; ++startRange) {
594 // Current operand name is not unique, find a unique one
595 // and set the alternative name.
596 for (int i = startSearchIndex;; ++i) {
597 std::string alternativeName = operandName + std::to_string(val: i);
598 if (!usedNames.contains(key: alternativeName) &&
599 symbolInfoMap.count(x: alternativeName) == 0) {
600 usedNames.insert(key: alternativeName);
601 startRange->second.alternativeName = alternativeName;
602 startSearchIndex = i + 1;
603
604 break;
605 }
606 }
607 }
608
609 symbolInfoIt = endRange;
610 }
611}
612
613//===----------------------------------------------------------------------===//
614// Pattern
//===----------------------------------------------------------------------===//
616
617Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
618 : def(*def), recordOpMap(mapper) {}
619
620DagNode Pattern::getSourcePattern() const {
621 return DagNode(def.getValueAsDag(FieldName: "sourcePattern"));
622}
623
624int Pattern::getNumResultPatterns() const {
625 auto *results = def.getValueAsListInit(FieldName: "resultPatterns");
626 return results->size();
627}
628
629DagNode Pattern::getResultPattern(unsigned index) const {
630 auto *results = def.getValueAsListInit(FieldName: "resultPatterns");
631 return DagNode(cast<llvm::DagInit>(Val: results->getElement(i: index)));
632}
633
634void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
635 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
636 collectBoundSymbols(tree: getSourcePattern(), infoMap, /*isSrcPattern=*/true);
637 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
638
639 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
640 infoMap.assignUniqueAlternativeNames();
641 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
642}
643
644void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
645 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
646 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
647 auto pattern = getResultPattern(index: i);
648 collectBoundSymbols(tree: pattern, infoMap, /*isSrcPattern=*/false);
649 }
650 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
651}
652
653const Operator &Pattern::getSourceRootOp() {
654 return getSourcePattern().getDialectOp(mapper: recordOpMap);
655}
656
657Operator &Pattern::getDialectOp(DagNode node) {
658 return node.getDialectOp(mapper: recordOpMap);
659}
660
661std::vector<AppliedConstraint> Pattern::getConstraints() const {
662 auto *listInit = def.getValueAsListInit(FieldName: "constraints");
663 std::vector<AppliedConstraint> ret;
664 ret.reserve(n: listInit->size());
665
666 for (auto *it : *listInit) {
667 auto *dagInit = dyn_cast<llvm::DagInit>(Val: it);
668 if (!dagInit)
669 PrintFatalError(Rec: &def, Msg: "all elements in Pattern multi-entity "
670 "constraints should be DAG nodes");
671
672 std::vector<std::string> entities;
673 entities.reserve(n: dagInit->arg_size());
674 for (auto *argName : dagInit->getArgNames()) {
675 if (!argName) {
676 PrintFatalError(
677 Rec: &def,
678 Msg: "operands to additional constraints can only be symbol references");
679 }
680 entities.emplace_back(args: argName->getValue());
681 }
682
683 ret.emplace_back(args: cast<llvm::DefInit>(Val: dagInit->getOperator())->getDef(),
684 args: dagInit->getNameStr(), args: std::move(entities));
685 }
686 return ret;
687}
688
689int Pattern::getNumSupplementalPatterns() const {
690 auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns");
691 return results->size();
692}
693
694DagNode Pattern::getSupplementalPattern(unsigned index) const {
695 auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns");
696 return DagNode(cast<llvm::DagInit>(Val: results->getElement(i: index)));
697}
698
699int Pattern::getBenefit() const {
700 // The initial benefit value is a heuristic with number of ops in the source
701 // pattern.
702 int initBenefit = getSourcePattern().getNumOps();
703 llvm::DagInit *delta = def.getValueAsDag(FieldName: "benefitDelta");
704 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(Val: delta->getArg(Num: 0))) {
705 PrintFatalError(Rec: &def,
706 Msg: "The 'addBenefit' takes and only takes one integer value");
707 }
708 return initBenefit + dyn_cast<llvm::IntInit>(Val: delta->getArg(Num: 0))->getValue();
709}
710
711std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
712 std::vector<std::pair<StringRef, unsigned>> result;
713 result.reserve(n: def.getLoc().size());
714 for (auto loc : def.getLoc()) {
715 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(Loc: loc);
716 assert(buf && "invalid source location");
717 result.emplace_back(
718 args: llvm::SrcMgr.getBufferInfo(i: buf).Buffer->getBufferIdentifier(),
719 args: llvm::SrcMgr.getLineAndColumn(Loc: loc, BufferID: buf).first);
720 }
721 return result;
722}
723
724void Pattern::verifyBind(bool result, StringRef symbolName) {
725 if (!result) {
726 auto err = formatv(Fmt: "symbol '{0}' bound more than once", Vals&: symbolName);
727 PrintFatalError(Rec: &def, Msg: err);
728 }
729}
730
731void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
732 bool isSrcPattern) {
733 auto treeName = tree.getSymbol();
734 auto numTreeArgs = tree.getNumArgs();
735
736 if (tree.isNativeCodeCall()) {
737 if (!treeName.empty()) {
738 if (!isSrcPattern) {
739 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
740 << treeName << '\n');
741 verifyBind(
742 result: infoMap.bindValues(symbol: treeName, numValues: tree.getNumReturnsOfNativeCode()),
743 symbolName: treeName);
744 } else {
745 PrintFatalError(Rec: &def,
746 Msg: formatv(Fmt: "binding symbol '{0}' to NativecodeCall in "
747 "MatchPattern is not supported",
748 Vals&: treeName));
749 }
750 }
751
752 for (int i = 0; i != numTreeArgs; ++i) {
753 if (auto treeArg = tree.getArgAsNestedDag(index: i)) {
754 // This DAG node argument is a DAG node itself. Go inside recursively.
755 collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern);
756 continue;
757 }
758
759 if (!isSrcPattern)
760 continue;
761
762 // We can only bind symbols to arguments in source pattern. Those
763 // symbols are referenced in result patterns.
764 auto treeArgName = tree.getArgName(index: i);
765
766 // `$_` is a special symbol meaning ignore the current argument.
767 if (!treeArgName.empty() && treeArgName != "_") {
768 DagLeaf leaf = tree.getArgAsLeaf(index: i);
769
770 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
771 if (leaf.isUnspecified()) {
772 // This is case of $c, a Value without any constraints.
773 verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName);
774 } else {
775 auto constraint = leaf.getAsConstraint();
776 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
777 leaf.isConstantAttr() ||
778 constraint.getKind() == Constraint::Kind::CK_Attr;
779
780 if (isAttr) {
781 // This is case of $a, a binding to a certain attribute.
782 verifyBind(result: infoMap.bindAttr(symbol: treeArgName), symbolName: treeArgName);
783 continue;
784 }
785
786 // This is case of $b, a binding to a certain type.
787 verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName);
788 }
789 }
790 }
791
792 return;
793 }
794
795 if (tree.isOperation()) {
796 auto &op = getDialectOp(node: tree);
797 auto numOpArgs = op.getNumArgs();
798 int numEither = 0;
799
800 // We need to exclude the trailing directives and `either` directive groups
801 // two operands of the operation.
802 int numDirectives = 0;
803 for (int i = numTreeArgs - 1; i >= 0; --i) {
804 if (auto dagArg = tree.getArgAsNestedDag(index: i)) {
805 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
806 ++numDirectives;
807 else if (dagArg.isEither())
808 ++numEither;
809 }
810 }
811
812 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
813 auto err =
814 formatv(Fmt: "op '{0}' argument number mismatch: "
815 "{1} in pattern vs. {2} in definition",
816 Vals: op.getOperationName(), Vals: numTreeArgs + numEither, Vals&: numOpArgs);
817 PrintFatalError(Rec: &def, Msg: err);
818 }
819
820 // The name attached to the DAG node's operator is for representing the
821 // results generated from this op. It should be remembered as bound results.
822 if (!treeName.empty()) {
823 LLVM_DEBUG(llvm::dbgs()
824 << "found symbol bound to op result: " << treeName << '\n');
825 verifyBind(result: infoMap.bindOpResult(symbol: treeName, op), symbolName: treeName);
826 }
827
828 // The operand in `either` DAG should be bound to the operation in the
829 // parent DagNode.
830 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
831 int opArgIdx) {
832 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
833 if (DagNode subTree = tree.getArgAsNestedDag(index: i)) {
834 collectBoundSymbols(tree: subTree, infoMap, isSrcPattern);
835 } else {
836 auto argName = tree.getArgName(index: i);
837 if (!argName.empty() && argName != "_") {
838 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx),
839 symbolName: argName);
840 }
841 }
842 }
843 };
844
845 // The operand in `variadic` DAG should be bound to the operation in the
846 // parent DagNode. The range index must be included as well to distinguish
847 // (potentially) repeating argName within the `variadic` DAG.
848 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
849 int opArgIdx) {
850 auto treeName = tree.getSymbol();
851 if (!treeName.empty()) {
852 // If treeName is specified, bind to the full variadic operand_range.
853 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: treeName, op, argIndex: opArgIdx,
854 variadicSubIndex: std::nullopt),
855 symbolName: treeName);
856 }
857
858 for (int i = 0; i < tree.getNumArgs(); ++i) {
859 if (DagNode subTree = tree.getArgAsNestedDag(index: i)) {
860 collectBoundSymbols(tree: subTree, infoMap, isSrcPattern);
861 } else {
862 auto argName = tree.getArgName(index: i);
863 if (!argName.empty() && argName != "_") {
864 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx,
865 /*variadicSubIndex=*/i),
866 symbolName: argName);
867 }
868 }
869 }
870 };
871
872 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
873 if (auto treeArg = tree.getArgAsNestedDag(index: i)) {
874 if (treeArg.isEither()) {
875 collectSymbolInEither(tree, treeArg, opArgIdx);
876 // `either` DAG is *flattened*. For example,
877 //
878 // (FooOp (either arg0, arg1), arg2)
879 //
880 // can be viewed as:
881 //
882 // (FooOp arg0, arg1, arg2)
883 ++opArgIdx;
884 } else if (treeArg.isVariadic()) {
885 collectSymbolInVariadic(tree, treeArg, opArgIdx);
886 } else {
887 // This DAG node argument is a DAG node itself. Go inside recursively.
888 collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern);
889 }
890 continue;
891 }
892
893 if (isSrcPattern) {
894 // We can only bind symbols to op arguments in source pattern. Those
895 // symbols are referenced in result patterns.
896 auto treeArgName = tree.getArgName(index: i);
897 // `$_` is a special symbol meaning ignore the current argument.
898 if (!treeArgName.empty() && treeArgName != "_") {
899 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
900 << treeArgName << '\n');
901 verifyBind(result: infoMap.bindOpArgument(node: tree, symbol: treeArgName, op, argIndex: opArgIdx),
902 symbolName: treeArgName);
903 }
904 }
905 }
906 return;
907 }
908
909 if (!treeName.empty()) {
910 PrintFatalError(
911 Rec: &def, Msg: formatv(Fmt: "binding symbol '{0}' to non-operation/native code call "
912 "unsupported right now",
913 Vals&: treeName));
914 }
915}
916

source code of mlir/lib/TableGen/Pattern.cpp