1//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pattern wrapper class to simplify using TableGen Record defining a MLIR
10// Pattern.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
16#include "mlir/TableGen/Pattern.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23
24#define DEBUG_TYPE "mlir-tblgen-pattern"
25
26using namespace mlir;
27using namespace tblgen;
28
29using llvm::DagInit;
30using llvm::dbgs;
31using llvm::DefInit;
32using llvm::formatv;
33using llvm::IntInit;
34using llvm::Record;
35
36//===----------------------------------------------------------------------===//
37// DagLeaf
38//===----------------------------------------------------------------------===//
39
40bool DagLeaf::isUnspecified() const {
41 return isa_and_nonnull<llvm::UnsetInit>(Val: def);
42}
43
44bool DagLeaf::isOperandMatcher() const {
45 // Operand matchers specify a type constraint.
46 return isSubClassOf(superclass: "TypeConstraint");
47}
48
49bool DagLeaf::isAttrMatcher() const {
50 // Attribute matchers specify an attribute constraint.
51 return isSubClassOf(superclass: "AttrConstraint");
52}
53
54bool DagLeaf::isPropMatcher() const {
55 // Property matchers specify a property constraint.
56 return isSubClassOf(superclass: "PropConstraint");
57}
58
59bool DagLeaf::isPropDefinition() const {
60 // Property matchers specify a property definition.
61 return isSubClassOf(superclass: "Property");
62}
63
64bool DagLeaf::isNativeCodeCall() const {
65 return isSubClassOf(superclass: "NativeCodeCall");
66}
67
68bool DagLeaf::isConstantAttr() const { return isSubClassOf(superclass: "ConstantAttr"); }
69
70bool DagLeaf::isEnumCase() const { return isSubClassOf(superclass: "EnumCase"); }
71
72bool DagLeaf::isConstantProp() const { return isSubClassOf(superclass: "ConstantProp"); }
73
74bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(Val: def); }
75
76Constraint DagLeaf::getAsConstraint() const {
77 assert((isOperandMatcher() || isAttrMatcher() || isPropMatcher()) &&
78 "the DAG leaf must be operand, attribute, or property");
79 return Constraint(cast<DefInit>(Val: def)->getDef());
80}
81
82PropConstraint DagLeaf::getAsPropConstraint() const {
83 assert(isPropMatcher() && "the DAG leaf must be a property matcher");
84 return PropConstraint(cast<DefInit>(Val: def)->getDef());
85}
86
87Property DagLeaf::getAsProperty() const {
88 assert(isPropDefinition() && "the DAG leaf must be a property definition");
89 return Property(cast<DefInit>(Val: def)->getDef());
90}
91
92ConstantAttr DagLeaf::getAsConstantAttr() const {
93 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
94 return ConstantAttr(cast<DefInit>(Val: def));
95}
96
97EnumCase DagLeaf::getAsEnumCase() const {
98 assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
99 return EnumCase(cast<DefInit>(Val: def));
100}
101
102ConstantProp DagLeaf::getAsConstantProp() const {
103 assert(isConstantProp() && "the DAG leaf must be a constant property value");
104 return ConstantProp(cast<DefInit>(Val: def));
105}
106
107std::string DagLeaf::getConditionTemplate() const {
108 return getAsConstraint().getConditionTemplate();
109}
110
111StringRef DagLeaf::getNativeCodeTemplate() const {
112 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
113 return cast<DefInit>(Val: def)->getDef()->getValueAsString(FieldName: "expression");
114}
115
116int DagLeaf::getNumReturnsOfNativeCode() const {
117 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
118 return cast<DefInit>(Val: def)->getDef()->getValueAsInt(FieldName: "numReturns");
119}
120
121std::string DagLeaf::getStringAttr() const {
122 assert(isStringAttr() && "the DAG leaf must be string attribute");
123 return def->getAsUnquotedString();
124}
125bool DagLeaf::isSubClassOf(StringRef superclass) const {
126 if (auto *defInit = dyn_cast_or_null<DefInit>(Val: def))
127 return defInit->getDef()->isSubClassOf(Name: superclass);
128 return false;
129}
130
131void DagLeaf::print(raw_ostream &os) const {
132 if (def)
133 def->print(OS&: os);
134}
135
136//===----------------------------------------------------------------------===//
137// DagNode
138//===----------------------------------------------------------------------===//
139
140bool DagNode::isNativeCodeCall() const {
141 if (auto *defInit = dyn_cast_or_null<DefInit>(Val: node->getOperator()))
142 return defInit->getDef()->isSubClassOf(Name: "NativeCodeCall");
143 return false;
144}
145
146bool DagNode::isOperation() const {
147 return !isNativeCodeCall() && !isReplaceWithValue() &&
148 !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
149 !isVariadic();
150}
151
152StringRef DagNode::getNativeCodeTemplate() const {
153 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
154 return cast<DefInit>(Val: node->getOperator())
155 ->getDef()
156 ->getValueAsString(FieldName: "expression");
157}
158
159int DagNode::getNumReturnsOfNativeCode() const {
160 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
161 return cast<DefInit>(Val: node->getOperator())
162 ->getDef()
163 ->getValueAsInt(FieldName: "numReturns");
164}
165
166StringRef DagNode::getSymbol() const { return node->getNameStr(); }
167
168Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
169 const Record *opDef = cast<DefInit>(Val: node->getOperator())->getDef();
170 auto [it, inserted] = mapper->try_emplace(Key: opDef);
171 if (inserted)
172 it->second = std::make_unique<Operator>(args&: opDef);
173 return *it->second;
174}
175
176int DagNode::getNumOps() const {
177 // We want to get number of operations recursively involved in the DAG tree.
178 // All other directives should be excluded.
179 int count = isOperation() ? 1 : 0;
180 for (int i = 0, e = getNumArgs(); i != e; ++i) {
181 if (auto child = getArgAsNestedDag(index: i))
182 count += child.getNumOps();
183 }
184 return count;
185}
186
187int DagNode::getNumArgs() const { return node->getNumArgs(); }
188
189bool DagNode::isNestedDagArg(unsigned index) const {
190 return isa<DagInit>(Val: node->getArg(Num: index));
191}
192
193DagNode DagNode::getArgAsNestedDag(unsigned index) const {
194 return DagNode(dyn_cast_or_null<DagInit>(Val: node->getArg(Num: index)));
195}
196
197DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
198 assert(!isNestedDagArg(index));
199 return DagLeaf(node->getArg(Num: index));
200}
201
202StringRef DagNode::getArgName(unsigned index) const {
203 return node->getArgNameStr(Num: index);
204}
205
206bool DagNode::isReplaceWithValue() const {
207 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
208 return dagOpDef->getName() == "replaceWithValue";
209}
210
211bool DagNode::isLocationDirective() const {
212 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
213 return dagOpDef->getName() == "location";
214}
215
216bool DagNode::isReturnTypeDirective() const {
217 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
218 return dagOpDef->getName() == "returnType";
219}
220
221bool DagNode::isEither() const {
222 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
223 return dagOpDef->getName() == "either";
224}
225
226bool DagNode::isVariadic() const {
227 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
228 return dagOpDef->getName() == "variadic";
229}
230
231void DagNode::print(raw_ostream &os) const {
232 if (node)
233 node->print(OS&: os);
234}
235
236//===----------------------------------------------------------------------===//
237// SymbolInfoMap
238//===----------------------------------------------------------------------===//
239
240StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
241 int idx = -1;
242 auto [name, indexStr] = symbol.rsplit(Separator: "__");
243
244 if (indexStr.consumeInteger(Radix: 10, Result&: idx)) {
245 // The second part is not an index; we return the whole symbol as-is.
246 return symbol;
247 }
248 if (index) {
249 *index = idx;
250 }
251 return name;
252}
253
254SymbolInfoMap::SymbolInfo::SymbolInfo(
255 const Operator *op, SymbolInfo::Kind kind,
256 std::optional<DagAndConstant> dagAndConstant)
257 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
258
259int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
260 switch (kind) {
261 case Kind::Attr:
262 case Kind::Prop:
263 case Kind::Operand:
264 case Kind::Value:
265 return 1;
266 case Kind::Result:
267 return op->getNumResults();
268 case Kind::MultipleValues:
269 return getSize();
270 }
271 llvm_unreachable("unknown kind");
272}
273
274std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
275 return alternativeName ? *alternativeName : name.str();
276}
277
278std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
279 LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
280 switch (kind) {
281 case Kind::Attr: {
282 if (op)
283 return cast<NamedAttribute *>(Val: op->getArg(index: getArgIndex()))
284 ->attr.getStorageType()
285 .str();
286 // TODO(suderman): Use a more exact type when available.
287 return "::mlir::Attribute";
288 }
289 case Kind::Prop: {
290 if (op)
291 return cast<NamedProperty *>(Val: op->getArg(index: getArgIndex()))
292 ->prop.getInterfaceType()
293 .str();
294 assert(dagAndConstant && dagAndConstant->dag &&
295 "generic properties must carry their constraint");
296 return reinterpret_cast<const DagLeaf *>(dagAndConstant->dag)
297 ->getAsPropConstraint()
298 .getInterfaceType()
299 .str();
300 }
301 case Kind::Operand: {
302 // Use operand range for captured operands (to support potential variadic
303 // operands).
304 return "::mlir::Operation::operand_range";
305 }
306 case Kind::Value: {
307 return "::mlir::Value";
308 }
309 case Kind::MultipleValues: {
310 return "::mlir::ValueRange";
311 }
312 case Kind::Result: {
313 // Use the op itself for captured results.
314 return op->getQualCppClassName();
315 }
316 }
317 llvm_unreachable("unknown kind");
318}
319
320std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
321 LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
322 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
323 return std::string(
324 formatv(Fmt: "{0} {1}{2};\n", Vals: getVarTypeStr(name), Vals: getVarName(name), Vals&: varInit));
325}
326
327std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
328 LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
329 return std::string(
330 formatv(Fmt: "{0} &{1}", Vals: getVarTypeStr(name), Vals: getVarName(name)));
331}
332
333std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
334 StringRef name, int index, const char *fmt, const char *separator) const {
335 LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
336 switch (kind) {
337 case Kind::Attr: {
338 assert(index < 0);
339 auto repl = formatv(Fmt: fmt, Vals&: name);
340 LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
341 return std::string(repl);
342 }
343 case Kind::Prop: {
344 assert(index < 0);
345 auto repl = formatv(Fmt: fmt, Vals&: name);
346 LLVM_DEBUG(dbgs() << repl << " (Prop)\n");
347 return std::string(repl);
348 }
349 case Kind::Operand: {
350 assert(index < 0);
351 auto *operand = cast<NamedTypeConstraint *>(Val: op->getArg(index: getArgIndex()));
352 if (operand->isOptional()) {
353 auto repl = formatv(
354 Fmt: fmt, Vals: formatv(Fmt: "({0}.empty() ? ::mlir::Value() : *{0}.begin())", Vals&: name));
355 LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
356 return std::string(repl);
357 }
358 // If this operand is variadic and this SymbolInfo doesn't have a range
359 // index, then return the full variadic operand_range. Otherwise, return
360 // the value itself.
361 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
362 auto repl = formatv(Fmt: fmt, Vals&: name);
363 LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
364 return std::string(repl);
365 }
366 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "(*{0}.begin())", Vals&: name));
367 LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
368 return std::string(repl);
369 }
370 case Kind::Result: {
371 // If `index` is greater than zero, then we are referencing a specific
372 // result of a multi-result op. The result can still be variadic.
373 if (index >= 0) {
374 std::string v =
375 std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index));
376 if (!op->getResult(index).isVariadic())
377 v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v));
378 auto repl = formatv(Fmt: fmt, Vals&: v);
379 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
380 return std::string(repl);
381 }
382
383 // If this op has no result at all but still we bind a symbol to it, it
384 // means we want to capture the op itself.
385 if (op->getNumResults() == 0) {
386 LLVM_DEBUG(dbgs() << name << " (Op)\n");
387 return formatv(Fmt: fmt, Vals&: name);
388 }
389
390 // We are referencing all results of the multi-result op. A specific result
391 // can either be a value or a range. Then join them with `separator`.
392 SmallVector<std::string, 4> values;
393 values.reserve(N: op->getNumResults());
394
395 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
396 std::string v = std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i));
397 if (!op->getResult(index: i).isVariadic()) {
398 v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v));
399 }
400 values.push_back(Elt: std::string(formatv(Fmt: fmt, Vals&: v)));
401 }
402 auto repl = llvm::join(R&: values, Separator: separator);
403 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
404 return repl;
405 }
406 case Kind::Value: {
407 assert(index < 0);
408 assert(op == nullptr);
409 auto repl = formatv(Fmt: fmt, Vals&: name);
410 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
411 return std::string(repl);
412 }
413 case Kind::MultipleValues: {
414 assert(op == nullptr);
415 assert(index < getSize());
416 if (index >= 0) {
417 std::string repl =
418 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index)));
419 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
420 return repl;
421 }
422 // If it doesn't specify certain element, unpack them all.
423 auto repl =
424 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name)));
425 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
426 return std::string(repl);
427 }
428 }
429 llvm_unreachable("unknown kind");
430}
431
432std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
433 StringRef name, int index, const char *fmt, const char *separator) const {
434 LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
435 switch (kind) {
436 case Kind::Attr:
437 case Kind::Prop:
438 case Kind::Operand: {
439 assert(index < 0 && "only allowed for symbol bound to result");
440 auto repl = formatv(Fmt: fmt, Vals&: name);
441 LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n");
442 return std::string(repl);
443 }
444 case Kind::Result: {
445 if (index >= 0) {
446 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index));
447 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
448 return std::string(repl);
449 }
450
451 // We are referencing all results of the multi-result op. Each result should
452 // have a value range, and then join them with `separator`.
453 SmallVector<std::string, 4> values;
454 values.reserve(N: op->getNumResults());
455
456 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
457 values.push_back(Elt: std::string(
458 formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i))));
459 }
460 auto repl = llvm::join(R&: values, Separator: separator);
461 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
462 return repl;
463 }
464 case Kind::Value: {
465 assert(index < 0 && "only allowed for symbol bound to result");
466 assert(op == nullptr);
467 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{{{0}}", Vals&: name));
468 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
469 return std::string(repl);
470 }
471 case Kind::MultipleValues: {
472 assert(op == nullptr);
473 assert(index < getSize());
474 if (index >= 0) {
475 std::string repl =
476 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index)));
477 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
478 return repl;
479 }
480 auto repl =
481 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name)));
482 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
483 return std::string(repl);
484 }
485 }
486 llvm_unreachable("unknown kind");
487}
488
489bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
490 const Operator &op, int argIndex,
491 std::optional<int> variadicSubIndex) {
492 StringRef name = getValuePackName(symbol);
493 if (name != symbol) {
494 auto error = formatv(
495 Fmt: "symbol '{0}' with trailing index cannot bind to op argument", Vals&: symbol);
496 PrintFatalError(ErrorLoc: loc, Msg: error);
497 }
498
499 Argument arg = op.getArg(index: argIndex);
500 SymbolInfo symInfo =
501 isa<NamedAttribute *>(Val: arg) ? SymbolInfo::getAttr(op: &op, index: argIndex)
502 : isa<NamedProperty *>(Val: arg)
503 ? SymbolInfo::getProp(op: &op, index: argIndex)
504 : SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex);
505
506 std::string key = symbol.str();
507 if (symbolInfoMap.count(x: key)) {
508 // Only non unique name for the operand is supported.
509 if (symInfo.kind != SymbolInfo::Kind::Operand) {
510 return false;
511 }
512
513 // Cannot add new operand if there is already non operand with the same
514 // name.
515 if (symbolInfoMap.find(x: key)->second.kind != SymbolInfo::Kind::Operand) {
516 return false;
517 }
518 }
519
520 symbolInfoMap.emplace(args&: key, args&: symInfo);
521 return true;
522}
523
524bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
525 std::string name = getValuePackName(symbol).str();
526 auto inserted = symbolInfoMap.emplace(args&: name, args: SymbolInfo::getResult(op: &op));
527
528 return symbolInfoMap.count(x: inserted->first) == 1;
529}
530
531bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
532 std::string name = getValuePackName(symbol).str();
533 if (numValues > 1)
534 return bindMultipleValues(symbol: name, numValues);
535 return bindValue(symbol: name);
536}
537
538bool SymbolInfoMap::bindValue(StringRef symbol) {
539 auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getValue());
540 return symbolInfoMap.count(x: inserted->first) == 1;
541}
542
543bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
544 std::string name = getValuePackName(symbol).str();
545 auto inserted =
546 symbolInfoMap.emplace(args&: name, args: SymbolInfo::getMultipleValues(numValues));
547 return symbolInfoMap.count(x: inserted->first) == 1;
548}
549
550bool SymbolInfoMap::bindAttr(StringRef symbol) {
551 auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getAttr());
552 return symbolInfoMap.count(x: inserted->first) == 1;
553}
554
555bool SymbolInfoMap::bindProp(StringRef symbol,
556 const PropConstraint &constraint) {
557 auto inserted =
558 symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getProp(constraint: &constraint));
559 return symbolInfoMap.count(x: inserted->first) == 1;
560}
561
562bool SymbolInfoMap::contains(StringRef symbol) const {
563 return find(key: symbol) != symbolInfoMap.end();
564}
565
566SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
567 std::string name = getValuePackName(symbol: key).str();
568
569 return symbolInfoMap.find(x: name);
570}
571
572SymbolInfoMap::const_iterator
573SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
574 int argIndex,
575 std::optional<int> variadicSubIndex) const {
576 return findBoundSymbol(
577 key, symbolInfo: SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex));
578}
579
580SymbolInfoMap::const_iterator
581SymbolInfoMap::findBoundSymbol(StringRef key,
582 const SymbolInfo &symbolInfo) const {
583 std::string name = getValuePackName(symbol: key).str();
584 auto range = symbolInfoMap.equal_range(x: name);
585
586 for (auto it = range.first; it != range.second; ++it)
587 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
588 return it;
589
590 return symbolInfoMap.end();
591}
592
593std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
594SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
595 std::string name = getValuePackName(symbol: key).str();
596
597 return symbolInfoMap.equal_range(x: name);
598}
599
600int SymbolInfoMap::count(StringRef key) const {
601 std::string name = getValuePackName(symbol: key).str();
602 return symbolInfoMap.count(x: name);
603}
604
605int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
606 StringRef name = getValuePackName(symbol);
607 if (name != symbol) {
608 // If there is a trailing index inside symbol, it references just one
609 // static value.
610 return 1;
611 }
612 // Otherwise, find how many it represents by querying the symbol's info.
613 return find(key: name)->second.getStaticValueCount();
614}
615
616std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
617 const char *fmt,
618 const char *separator) const {
619 int index = -1;
620 StringRef name = getValuePackName(symbol, index: &index);
621
622 auto it = symbolInfoMap.find(x: name.str());
623 if (it == symbolInfoMap.end()) {
624 auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol);
625 PrintFatalError(ErrorLoc: loc, Msg: error);
626 }
627
628 return it->second.getValueAndRangeUse(name, index, fmt, separator);
629}
630
631std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
632 const char *separator) const {
633 int index = -1;
634 StringRef name = getValuePackName(symbol, index: &index);
635
636 auto it = symbolInfoMap.find(x: name.str());
637 if (it == symbolInfoMap.end()) {
638 auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol);
639 PrintFatalError(ErrorLoc: loc, Msg: error);
640 }
641
642 return it->second.getAllRangeUse(name, index, fmt, separator);
643}
644
645void SymbolInfoMap::assignUniqueAlternativeNames() {
646 llvm::StringSet<> usedNames;
647
648 for (auto symbolInfoIt = symbolInfoMap.begin();
649 symbolInfoIt != symbolInfoMap.end();) {
650 auto range = symbolInfoMap.equal_range(x: symbolInfoIt->first);
651 auto startRange = range.first;
652 auto endRange = range.second;
653
654 auto operandName = symbolInfoIt->first;
655 int startSearchIndex = 0;
656 for (++startRange; startRange != endRange; ++startRange) {
657 // Current operand name is not unique, find a unique one
658 // and set the alternative name.
659 for (int i = startSearchIndex;; ++i) {
660 std::string alternativeName = operandName + std::to_string(val: i);
661 if (!usedNames.contains(key: alternativeName) &&
662 symbolInfoMap.count(x: alternativeName) == 0) {
663 usedNames.insert(key: alternativeName);
664 startRange->second.alternativeName = alternativeName;
665 startSearchIndex = i + 1;
666
667 break;
668 }
669 }
670 }
671
672 symbolInfoIt = endRange;
673 }
674}
675
676//===----------------------------------------------------------------------===//
677// Pattern
//===----------------------------------------------------------------------===//
679
680Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
681 : def(*def), recordOpMap(mapper) {}
682
683DagNode Pattern::getSourcePattern() const {
684 return DagNode(def.getValueAsDag(FieldName: "sourcePattern"));
685}
686
687int Pattern::getNumResultPatterns() const {
688 auto *results = def.getValueAsListInit(FieldName: "resultPatterns");
689 return results->size();
690}
691
692DagNode Pattern::getResultPattern(unsigned index) const {
693 auto *results = def.getValueAsListInit(FieldName: "resultPatterns");
694 return DagNode(cast<DagInit>(Val: results->getElement(Idx: index)));
695}
696
697void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
698 LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
699 collectBoundSymbols(tree: getSourcePattern(), infoMap, /*isSrcPattern=*/true);
700 LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
701
702 LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
703 infoMap.assignUniqueAlternativeNames();
704 LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
705}
706
707void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
708 LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
709 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
710 auto pattern = getResultPattern(index: i);
711 collectBoundSymbols(tree: pattern, infoMap, /*isSrcPattern=*/false);
712 }
713 LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
714}
715
716const Operator &Pattern::getSourceRootOp() {
717 return getSourcePattern().getDialectOp(mapper: recordOpMap);
718}
719
720Operator &Pattern::getDialectOp(DagNode node) {
721 return node.getDialectOp(mapper: recordOpMap);
722}
723
724std::vector<AppliedConstraint> Pattern::getConstraints() const {
725 auto *listInit = def.getValueAsListInit(FieldName: "constraints");
726 std::vector<AppliedConstraint> ret;
727 ret.reserve(n: listInit->size());
728
729 for (auto *it : *listInit) {
730 auto *dagInit = dyn_cast<DagInit>(Val: it);
731 if (!dagInit)
732 PrintFatalError(Rec: &def, Msg: "all elements in Pattern multi-entity "
733 "constraints should be DAG nodes");
734
735 std::vector<std::string> entities;
736 entities.reserve(n: dagInit->arg_size());
737 for (auto *argName : dagInit->getArgNames()) {
738 if (!argName) {
739 PrintFatalError(
740 Rec: &def,
741 Msg: "operands to additional constraints can only be symbol references");
742 }
743 entities.emplace_back(args: argName->getValue());
744 }
745
746 ret.emplace_back(args: cast<DefInit>(Val: dagInit->getOperator())->getDef(),
747 args: dagInit->getNameStr(), args: std::move(entities));
748 }
749 return ret;
750}
751
752int Pattern::getNumSupplementalPatterns() const {
753 auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns");
754 return results->size();
755}
756
757DagNode Pattern::getSupplementalPattern(unsigned index) const {
758 auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns");
759 return DagNode(cast<DagInit>(Val: results->getElement(Idx: index)));
760}
761
762int Pattern::getBenefit() const {
763 // The initial benefit value is a heuristic with number of ops in the source
764 // pattern.
765 int initBenefit = getSourcePattern().getNumOps();
766 const DagInit *delta = def.getValueAsDag(FieldName: "benefitDelta");
767 if (delta->getNumArgs() != 1 || !isa<IntInit>(Val: delta->getArg(Num: 0))) {
768 PrintFatalError(Rec: &def,
769 Msg: "The 'addBenefit' takes and only takes one integer value");
770 }
771 return initBenefit + dyn_cast<IntInit>(Val: delta->getArg(Num: 0))->getValue();
772}
773
774std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
775 std::vector<std::pair<StringRef, unsigned>> result;
776 result.reserve(n: def.getLoc().size());
777 for (auto loc : def.getLoc()) {
778 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(Loc: loc);
779 assert(buf && "invalid source location");
780 result.emplace_back(
781 args: llvm::SrcMgr.getBufferInfo(i: buf).Buffer->getBufferIdentifier(),
782 args: llvm::SrcMgr.getLineAndColumn(Loc: loc, BufferID: buf).first);
783 }
784 return result;
785}
786
787void Pattern::verifyBind(bool result, StringRef symbolName) {
788 if (!result) {
789 auto err = formatv(Fmt: "symbol '{0}' bound more than once", Vals&: symbolName);
790 PrintFatalError(Rec: &def, Msg: err);
791 }
792}
793
794void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
795 bool isSrcPattern) {
796 auto treeName = tree.getSymbol();
797 auto numTreeArgs = tree.getNumArgs();
798
799 if (tree.isNativeCodeCall()) {
800 if (!treeName.empty()) {
801 if (!isSrcPattern) {
802 LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
803 << treeName << '\n');
804 verifyBind(
805 result: infoMap.bindValues(symbol: treeName, numValues: tree.getNumReturnsOfNativeCode()),
806 symbolName: treeName);
807 } else {
808 PrintFatalError(Rec: &def,
809 Msg: formatv(Fmt: "binding symbol '{0}' to NativecodeCall in "
810 "MatchPattern is not supported",
811 Vals&: treeName));
812 }
813 }
814
815 for (int i = 0; i != numTreeArgs; ++i) {
816 if (auto treeArg = tree.getArgAsNestedDag(index: i)) {
817 // This DAG node argument is a DAG node itself. Go inside recursively.
818 collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern);
819 continue;
820 }
821
822 if (!isSrcPattern)
823 continue;
824
825 // We can only bind symbols to arguments in source pattern. Those
826 // symbols are referenced in result patterns.
827 auto treeArgName = tree.getArgName(index: i);
828
829 // `$_` is a special symbol meaning ignore the current argument.
830 if (!treeArgName.empty() && treeArgName != "_") {
831 DagLeaf leaf = tree.getArgAsLeaf(index: i);
832
833 // In (NativeCodeCall<"Foo($_self, $0, $1, $2, $3)"> I8Attr:$a, I8:$b,
834 // $c, I8Prop:$d),
835 if (leaf.isUnspecified()) {
836 // This is case of $c, a Value without any constraints.
837 verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName);
838 } else if (leaf.isPropMatcher()) {
839 // This is case of $d, a binding to a certain property.
840 auto propConstraint = leaf.getAsPropConstraint();
841 if (propConstraint.getInterfaceType().empty()) {
842 PrintFatalError(Rec: &def,
843 Msg: formatv(Fmt: "binding symbol '{0}' in NativeCodeCall to "
844 "a property constraint without specifying "
845 "that constraint's type is unsupported",
846 Vals&: treeArgName));
847 }
848 verifyBind(result: infoMap.bindProp(symbol: treeArgName, constraint: propConstraint),
849 symbolName: treeArgName);
850 } else {
851 auto constraint = leaf.getAsConstraint();
852 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
853 leaf.isConstantAttr() ||
854 constraint.getKind() == Constraint::Kind::CK_Attr;
855
856 if (isAttr) {
857 // This is case of $a, a binding to a certain attribute.
858 verifyBind(result: infoMap.bindAttr(symbol: treeArgName), symbolName: treeArgName);
859 continue;
860 }
861
862 // This is case of $b, a binding to a certain type.
863 verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName);
864 }
865 }
866 }
867
868 return;
869 }
870
871 if (tree.isOperation()) {
872 auto &op = getDialectOp(node: tree);
873 auto numOpArgs = op.getNumArgs();
874 int numEither = 0;
875
876 // We need to exclude the trailing directives and `either` directive groups
877 // two operands of the operation.
878 int numDirectives = 0;
879 for (int i = numTreeArgs - 1; i >= 0; --i) {
880 if (auto dagArg = tree.getArgAsNestedDag(index: i)) {
881 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
882 ++numDirectives;
883 else if (dagArg.isEither())
884 ++numEither;
885 }
886 }
887
888 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
889 auto err =
890 formatv(Fmt: "op '{0}' argument number mismatch: "
891 "{1} in pattern vs. {2} in definition",
892 Vals: op.getOperationName(), Vals: numTreeArgs + numEither, Vals&: numOpArgs);
893 PrintFatalError(Rec: &def, Msg: err);
894 }
895
896 // The name attached to the DAG node's operator is for representing the
897 // results generated from this op. It should be remembered as bound results.
898 if (!treeName.empty()) {
899 LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
900 << '\n');
901 verifyBind(result: infoMap.bindOpResult(symbol: treeName, op), symbolName: treeName);
902 }
903
904 // The operand in `either` DAG should be bound to the operation in the
905 // parent DagNode.
906 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
907 int opArgIdx) {
908 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
909 if (DagNode subTree = tree.getArgAsNestedDag(index: i)) {
910 collectBoundSymbols(tree: subTree, infoMap, isSrcPattern);
911 } else {
912 auto argName = tree.getArgName(index: i);
913 if (!argName.empty() && argName != "_") {
914 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx),
915 symbolName: argName);
916 }
917 }
918 }
919 };
920
921 // The operand in `variadic` DAG should be bound to the operation in the
922 // parent DagNode. The range index must be included as well to distinguish
923 // (potentially) repeating argName within the `variadic` DAG.
924 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
925 int opArgIdx) {
926 auto treeName = tree.getSymbol();
927 if (!treeName.empty()) {
928 // If treeName is specified, bind to the full variadic operand_range.
929 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: treeName, op, argIndex: opArgIdx,
930 variadicSubIndex: std::nullopt),
931 symbolName: treeName);
932 }
933
934 for (int i = 0; i < tree.getNumArgs(); ++i) {
935 if (DagNode subTree = tree.getArgAsNestedDag(index: i)) {
936 collectBoundSymbols(tree: subTree, infoMap, isSrcPattern);
937 } else {
938 auto argName = tree.getArgName(index: i);
939 if (!argName.empty() && argName != "_") {
940 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx,
941 /*variadicSubIndex=*/i),
942 symbolName: argName);
943 }
944 }
945 }
946 };
947
948 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
949 if (auto treeArg = tree.getArgAsNestedDag(index: i)) {
950 if (treeArg.isEither()) {
951 collectSymbolInEither(tree, treeArg, opArgIdx);
952 // `either` DAG is *flattened*. For example,
953 //
954 // (FooOp (either arg0, arg1), arg2)
955 //
956 // can be viewed as:
957 //
958 // (FooOp arg0, arg1, arg2)
959 ++opArgIdx;
960 } else if (treeArg.isVariadic()) {
961 collectSymbolInVariadic(tree, treeArg, opArgIdx);
962 } else {
963 // This DAG node argument is a DAG node itself. Go inside recursively.
964 collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern);
965 }
966 continue;
967 }
968
969 if (isSrcPattern) {
970 // We can only bind symbols to op arguments in source pattern. Those
971 // symbols are referenced in result patterns.
972 auto treeArgName = tree.getArgName(index: i);
973 // `$_` is a special symbol meaning ignore the current argument.
974 if (!treeArgName.empty() && treeArgName != "_") {
975 LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
976 << treeArgName << '\n');
977 verifyBind(result: infoMap.bindOpArgument(node: tree, symbol: treeArgName, op, argIndex: opArgIdx),
978 symbolName: treeArgName);
979 }
980 }
981 }
982 return;
983 }
984
985 if (!treeName.empty()) {
986 PrintFatalError(
987 Rec: &def, Msg: formatv(Fmt: "binding symbol '{0}' to non-operation/native code call "
988 "unsupported right now",
989 Vals&: treeName));
990 }
991}
992

source code of mlir/lib/TableGen/Pattern.cpp