1//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pattern wrapper class to simplify using TableGen Record defining a MLIR
10// Pattern.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
16#include "mlir/TableGen/Pattern.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/TableGen/Error.h"
22#include "llvm/TableGen/Record.h"
23
24#define DEBUG_TYPE "mlir-tblgen-pattern"
25
26using namespace mlir;
27using namespace tblgen;
28
29using llvm::DagInit;
30using llvm::dbgs;
31using llvm::DefInit;
32using llvm::formatv;
33using llvm::IntInit;
34using llvm::Record;
35
36//===----------------------------------------------------------------------===//
37// DagLeaf
38//===----------------------------------------------------------------------===//
39
40bool DagLeaf::isUnspecified() const {
41 return isa_and_nonnull<llvm::UnsetInit>(Val: def);
42}
43
44bool DagLeaf::isOperandMatcher() const {
45 // Operand matchers specify a type constraint.
46 return isSubClassOf(superclass: "TypeConstraint");
47}
48
49bool DagLeaf::isAttrMatcher() const {
50 // Attribute matchers specify an attribute constraint.
51 return isSubClassOf(superclass: "AttrConstraint");
52}
53
54bool DagLeaf::isNativeCodeCall() const {
55 return isSubClassOf(superclass: "NativeCodeCall");
56}
57
58bool DagLeaf::isConstantAttr() const { return isSubClassOf(superclass: "ConstantAttr"); }
59
60bool DagLeaf::isEnumCase() const { return isSubClassOf(superclass: "EnumCase"); }
61
62bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(Val: def); }
63
64Constraint DagLeaf::getAsConstraint() const {
65 assert((isOperandMatcher() || isAttrMatcher()) &&
66 "the DAG leaf must be operand or attribute");
67 return Constraint(cast<DefInit>(Val: def)->getDef());
68}
69
70ConstantAttr DagLeaf::getAsConstantAttr() const {
71 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
72 return ConstantAttr(cast<DefInit>(Val: def));
73}
74
75EnumCase DagLeaf::getAsEnumCase() const {
76 assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
77 return EnumCase(cast<DefInit>(Val: def));
78}
79
80std::string DagLeaf::getConditionTemplate() const {
81 return getAsConstraint().getConditionTemplate();
82}
83
84StringRef DagLeaf::getNativeCodeTemplate() const {
85 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86 return cast<DefInit>(Val: def)->getDef()->getValueAsString(FieldName: "expression");
87}
88
89int DagLeaf::getNumReturnsOfNativeCode() const {
90 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
91 return cast<DefInit>(Val: def)->getDef()->getValueAsInt(FieldName: "numReturns");
92}
93
94std::string DagLeaf::getStringAttr() const {
95 assert(isStringAttr() && "the DAG leaf must be string attribute");
96 return def->getAsUnquotedString();
97}
98bool DagLeaf::isSubClassOf(StringRef superclass) const {
99 if (auto *defInit = dyn_cast_or_null<DefInit>(Val: def))
100 return defInit->getDef()->isSubClassOf(Name: superclass);
101 return false;
102}
103
104void DagLeaf::print(raw_ostream &os) const {
105 if (def)
106 def->print(OS&: os);
107}
108
109//===----------------------------------------------------------------------===//
110// DagNode
111//===----------------------------------------------------------------------===//
112
113bool DagNode::isNativeCodeCall() const {
114 if (auto *defInit = dyn_cast_or_null<DefInit>(Val: node->getOperator()))
115 return defInit->getDef()->isSubClassOf(Name: "NativeCodeCall");
116 return false;
117}
118
119bool DagNode::isOperation() const {
120 return !isNativeCodeCall() && !isReplaceWithValue() &&
121 !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
122 !isVariadic();
123}
124
125StringRef DagNode::getNativeCodeTemplate() const {
126 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
127 return cast<DefInit>(Val: node->getOperator())
128 ->getDef()
129 ->getValueAsString(FieldName: "expression");
130}
131
132int DagNode::getNumReturnsOfNativeCode() const {
133 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
134 return cast<DefInit>(Val: node->getOperator())
135 ->getDef()
136 ->getValueAsInt(FieldName: "numReturns");
137}
138
139StringRef DagNode::getSymbol() const { return node->getNameStr(); }
140
141Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
142 const Record *opDef = cast<DefInit>(Val: node->getOperator())->getDef();
143 auto [it, inserted] = mapper->try_emplace(Key: opDef);
144 if (inserted)
145 it->second = std::make_unique<Operator>(args&: opDef);
146 return *it->second;
147}
148
149int DagNode::getNumOps() const {
150 // We want to get number of operations recursively involved in the DAG tree.
151 // All other directives should be excluded.
152 int count = isOperation() ? 1 : 0;
153 for (int i = 0, e = getNumArgs(); i != e; ++i) {
154 if (auto child = getArgAsNestedDag(index: i))
155 count += child.getNumOps();
156 }
157 return count;
158}
159
160int DagNode::getNumArgs() const { return node->getNumArgs(); }
161
162bool DagNode::isNestedDagArg(unsigned index) const {
163 return isa<DagInit>(Val: node->getArg(Num: index));
164}
165
166DagNode DagNode::getArgAsNestedDag(unsigned index) const {
167 return DagNode(dyn_cast_or_null<DagInit>(Val: node->getArg(Num: index)));
168}
169
170DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
171 assert(!isNestedDagArg(index));
172 return DagLeaf(node->getArg(Num: index));
173}
174
175StringRef DagNode::getArgName(unsigned index) const {
176 return node->getArgNameStr(Num: index);
177}
178
179bool DagNode::isReplaceWithValue() const {
180 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
181 return dagOpDef->getName() == "replaceWithValue";
182}
183
184bool DagNode::isLocationDirective() const {
185 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
186 return dagOpDef->getName() == "location";
187}
188
189bool DagNode::isReturnTypeDirective() const {
190 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
191 return dagOpDef->getName() == "returnType";
192}
193
194bool DagNode::isEither() const {
195 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
196 return dagOpDef->getName() == "either";
197}
198
199bool DagNode::isVariadic() const {
200 auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef();
201 return dagOpDef->getName() == "variadic";
202}
203
204void DagNode::print(raw_ostream &os) const {
205 if (node)
206 node->print(OS&: os);
207}
208
209//===----------------------------------------------------------------------===//
210// SymbolInfoMap
211//===----------------------------------------------------------------------===//
212
213StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
214 int idx = -1;
215 auto [name, indexStr] = symbol.rsplit(Separator: "__");
216
217 if (indexStr.consumeInteger(Radix: 10, Result&: idx)) {
218 // The second part is not an index; we return the whole symbol as-is.
219 return symbol;
220 }
221 if (index) {
222 *index = idx;
223 }
224 return name;
225}
226
227SymbolInfoMap::SymbolInfo::SymbolInfo(
228 const Operator *op, SymbolInfo::Kind kind,
229 std::optional<DagAndConstant> dagAndConstant)
230 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
231
232int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
233 switch (kind) {
234 case Kind::Attr:
235 case Kind::Operand:
236 case Kind::Value:
237 return 1;
238 case Kind::Result:
239 return op->getNumResults();
240 case Kind::MultipleValues:
241 return getSize();
242 }
243 llvm_unreachable("unknown kind");
244}
245
246std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
247 return alternativeName ? *alternativeName : name.str();
248}
249
250std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
251 LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
252 switch (kind) {
253 case Kind::Attr: {
254 if (op)
255 return cast<NamedAttribute *>(Val: op->getArg(index: getArgIndex()))
256 ->attr.getStorageType()
257 .str();
258 // TODO(suderman): Use a more exact type when available.
259 return "::mlir::Attribute";
260 }
261 case Kind::Operand: {
262 // Use operand range for captured operands (to support potential variadic
263 // operands).
264 return "::mlir::Operation::operand_range";
265 }
266 case Kind::Value: {
267 return "::mlir::Value";
268 }
269 case Kind::MultipleValues: {
270 return "::mlir::ValueRange";
271 }
272 case Kind::Result: {
273 // Use the op itself for captured results.
274 return op->getQualCppClassName();
275 }
276 }
277 llvm_unreachable("unknown kind");
278}
279
280std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
281 LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
282 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
283 return std::string(
284 formatv(Fmt: "{0} {1}{2};\n", Vals: getVarTypeStr(name), Vals: getVarName(name), Vals&: varInit));
285}
286
287std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
288 LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
289 return std::string(
290 formatv(Fmt: "{0} &{1}", Vals: getVarTypeStr(name), Vals: getVarName(name)));
291}
292
293std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
294 StringRef name, int index, const char *fmt, const char *separator) const {
295 LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
296 switch (kind) {
297 case Kind::Attr: {
298 assert(index < 0);
299 auto repl = formatv(Fmt: fmt, Vals&: name);
300 LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
301 return std::string(repl);
302 }
303 case Kind::Operand: {
304 assert(index < 0);
305 auto *operand = cast<NamedTypeConstraint *>(Val: op->getArg(index: getArgIndex()));
306 if (operand->isOptional()) {
307 auto repl = formatv(
308 Fmt: fmt, Vals: formatv(Fmt: "({0}.empty() ? ::mlir::Value() : *{0}.begin())", Vals&: name));
309 LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
310 return std::string(repl);
311 }
312 // If this operand is variadic and this SymbolInfo doesn't have a range
313 // index, then return the full variadic operand_range. Otherwise, return
314 // the value itself.
315 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
316 auto repl = formatv(Fmt: fmt, Vals&: name);
317 LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
318 return std::string(repl);
319 }
320 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "(*{0}.begin())", Vals&: name));
321 LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
322 return std::string(repl);
323 }
324 case Kind::Result: {
325 // If `index` is greater than zero, then we are referencing a specific
326 // result of a multi-result op. The result can still be variadic.
327 if (index >= 0) {
328 std::string v =
329 std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index));
330 if (!op->getResult(index).isVariadic())
331 v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v));
332 auto repl = formatv(Fmt: fmt, Vals&: v);
333 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
334 return std::string(repl);
335 }
336
337 // If this op has no result at all but still we bind a symbol to it, it
338 // means we want to capture the op itself.
339 if (op->getNumResults() == 0) {
340 LLVM_DEBUG(dbgs() << name << " (Op)\n");
341 return formatv(Fmt: fmt, Vals&: name);
342 }
343
344 // We are referencing all results of the multi-result op. A specific result
345 // can either be a value or a range. Then join them with `separator`.
346 SmallVector<std::string, 4> values;
347 values.reserve(N: op->getNumResults());
348
349 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
350 std::string v = std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i));
351 if (!op->getResult(index: i).isVariadic()) {
352 v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v));
353 }
354 values.push_back(Elt: std::string(formatv(Fmt: fmt, Vals&: v)));
355 }
356 auto repl = llvm::join(R&: values, Separator: separator);
357 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
358 return repl;
359 }
360 case Kind::Value: {
361 assert(index < 0);
362 assert(op == nullptr);
363 auto repl = formatv(Fmt: fmt, Vals&: name);
364 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
365 return std::string(repl);
366 }
367 case Kind::MultipleValues: {
368 assert(op == nullptr);
369 assert(index < getSize());
370 if (index >= 0) {
371 std::string repl =
372 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index)));
373 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
374 return repl;
375 }
376 // If it doesn't specify certain element, unpack them all.
377 auto repl =
378 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name)));
379 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
380 return std::string(repl);
381 }
382 }
383 llvm_unreachable("unknown kind");
384}
385
386std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
387 StringRef name, int index, const char *fmt, const char *separator) const {
388 LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
389 switch (kind) {
390 case Kind::Attr:
391 case Kind::Operand: {
392 assert(index < 0 && "only allowed for symbol bound to result");
393 auto repl = formatv(Fmt: fmt, Vals&: name);
394 LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
395 return std::string(repl);
396 }
397 case Kind::Result: {
398 if (index >= 0) {
399 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index));
400 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
401 return std::string(repl);
402 }
403
404 // We are referencing all results of the multi-result op. Each result should
405 // have a value range, and then join them with `separator`.
406 SmallVector<std::string, 4> values;
407 values.reserve(N: op->getNumResults());
408
409 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
410 values.push_back(Elt: std::string(
411 formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i))));
412 }
413 auto repl = llvm::join(R&: values, Separator: separator);
414 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
415 return repl;
416 }
417 case Kind::Value: {
418 assert(index < 0 && "only allowed for symbol bound to result");
419 assert(op == nullptr);
420 auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{{{0}}", Vals&: name));
421 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
422 return std::string(repl);
423 }
424 case Kind::MultipleValues: {
425 assert(op == nullptr);
426 assert(index < getSize());
427 if (index >= 0) {
428 std::string repl =
429 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index)));
430 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
431 return repl;
432 }
433 auto repl =
434 formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name)));
435 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
436 return std::string(repl);
437 }
438 }
439 llvm_unreachable("unknown kind");
440}
441
442bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
443 const Operator &op, int argIndex,
444 std::optional<int> variadicSubIndex) {
445 StringRef name = getValuePackName(symbol);
446 if (name != symbol) {
447 auto error = formatv(
448 Fmt: "symbol '{0}' with trailing index cannot bind to op argument", Vals&: symbol);
449 PrintFatalError(ErrorLoc: loc, Msg: error);
450 }
451
452 auto symInfo =
453 isa<NamedAttribute *>(Val: op.getArg(index: argIndex))
454 ? SymbolInfo::getAttr(op: &op, index: argIndex)
455 : SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex);
456
457 std::string key = symbol.str();
458 if (symbolInfoMap.count(x: key)) {
459 // Only non unique name for the operand is supported.
460 if (symInfo.kind != SymbolInfo::Kind::Operand) {
461 return false;
462 }
463
464 // Cannot add new operand if there is already non operand with the same
465 // name.
466 if (symbolInfoMap.find(x: key)->second.kind != SymbolInfo::Kind::Operand) {
467 return false;
468 }
469 }
470
471 symbolInfoMap.emplace(args&: key, args&: symInfo);
472 return true;
473}
474
475bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
476 std::string name = getValuePackName(symbol).str();
477 auto inserted = symbolInfoMap.emplace(args&: name, args: SymbolInfo::getResult(op: &op));
478
479 return symbolInfoMap.count(x: inserted->first) == 1;
480}
481
482bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
483 std::string name = getValuePackName(symbol).str();
484 if (numValues > 1)
485 return bindMultipleValues(symbol: name, numValues);
486 return bindValue(symbol: name);
487}
488
489bool SymbolInfoMap::bindValue(StringRef symbol) {
490 auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getValue());
491 return symbolInfoMap.count(x: inserted->first) == 1;
492}
493
494bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
495 std::string name = getValuePackName(symbol).str();
496 auto inserted =
497 symbolInfoMap.emplace(args&: name, args: SymbolInfo::getMultipleValues(numValues));
498 return symbolInfoMap.count(x: inserted->first) == 1;
499}
500
501bool SymbolInfoMap::bindAttr(StringRef symbol) {
502 auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getAttr());
503 return symbolInfoMap.count(x: inserted->first) == 1;
504}
505
506bool SymbolInfoMap::contains(StringRef symbol) const {
507 return find(key: symbol) != symbolInfoMap.end();
508}
509
510SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
511 std::string name = getValuePackName(symbol: key).str();
512
513 return symbolInfoMap.find(x: name);
514}
515
516SymbolInfoMap::const_iterator
517SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
518 int argIndex,
519 std::optional<int> variadicSubIndex) const {
520 return findBoundSymbol(
521 key, symbolInfo: SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex));
522}
523
524SymbolInfoMap::const_iterator
525SymbolInfoMap::findBoundSymbol(StringRef key,
526 const SymbolInfo &symbolInfo) const {
527 std::string name = getValuePackName(symbol: key).str();
528 auto range = symbolInfoMap.equal_range(x: name);
529
530 for (auto it = range.first; it != range.second; ++it)
531 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
532 return it;
533
534 return symbolInfoMap.end();
535}
536
537std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
538SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
539 std::string name = getValuePackName(symbol: key).str();
540
541 return symbolInfoMap.equal_range(x: name);
542}
543
544int SymbolInfoMap::count(StringRef key) const {
545 std::string name = getValuePackName(symbol: key).str();
546 return symbolInfoMap.count(x: name);
547}
548
549int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
550 StringRef name = getValuePackName(symbol);
551 if (name != symbol) {
552 // If there is a trailing index inside symbol, it references just one
553 // static value.
554 return 1;
555 }
556 // Otherwise, find how many it represents by querying the symbol's info.
557 return find(key: name)->second.getStaticValueCount();
558}
559
560std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
561 const char *fmt,
562 const char *separator) const {
563 int index = -1;
564 StringRef name = getValuePackName(symbol, index: &index);
565
566 auto it = symbolInfoMap.find(x: name.str());
567 if (it == symbolInfoMap.end()) {
568 auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol);
569 PrintFatalError(ErrorLoc: loc, Msg: error);
570 }
571
572 return it->second.getValueAndRangeUse(name, index, fmt, separator);
573}
574
575std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
576 const char *separator) const {
577 int index = -1;
578 StringRef name = getValuePackName(symbol, index: &index);
579
580 auto it = symbolInfoMap.find(x: name.str());
581 if (it == symbolInfoMap.end()) {
582 auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol);
583 PrintFatalError(ErrorLoc: loc, Msg: error);
584 }
585
586 return it->second.getAllRangeUse(name, index, fmt, separator);
587}
588
589void SymbolInfoMap::assignUniqueAlternativeNames() {
590 llvm::StringSet<> usedNames;
591
592 for (auto symbolInfoIt = symbolInfoMap.begin();
593 symbolInfoIt != symbolInfoMap.end();) {
594 auto range = symbolInfoMap.equal_range(x: symbolInfoIt->first);
595 auto startRange = range.first;
596 auto endRange = range.second;
597
598 auto operandName = symbolInfoIt->first;
599 int startSearchIndex = 0;
600 for (++startRange; startRange != endRange; ++startRange) {
601 // Current operand name is not unique, find a unique one
602 // and set the alternative name.
603 for (int i = startSearchIndex;; ++i) {
604 std::string alternativeName = operandName + std::to_string(val: i);
605 if (!usedNames.contains(key: alternativeName) &&
606 symbolInfoMap.count(x: alternativeName) == 0) {
607 usedNames.insert(key: alternativeName);
608 startRange->second.alternativeName = alternativeName;
609 startSearchIndex = i + 1;
610
611 break;
612 }
613 }
614 }
615
616 symbolInfoIt = endRange;
617 }
618}
619
620//===----------------------------------------------------------------------===//
621// Pattern
//===----------------------------------------------------------------------===//
623
624Pattern::Pattern(const Record *def, RecordOperatorMap *mapper)
625 : def(*def), recordOpMap(mapper) {}
626
627DagNode Pattern::getSourcePattern() const {
628 return DagNode(def.getValueAsDag(FieldName: "sourcePattern"));
629}
630
631int Pattern::getNumResultPatterns() const {
632 auto *results = def.getValueAsListInit(FieldName: "resultPatterns");
633 return results->size();
634}
635
636DagNode Pattern::getResultPattern(unsigned index) const {
637 auto *results = def.getValueAsListInit(FieldName: "resultPatterns");
638 return DagNode(cast<DagInit>(Val: results->getElement(Idx: index)));
639}
640
641void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
642 LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
643 collectBoundSymbols(tree: getSourcePattern(), infoMap, /*isSrcPattern=*/true);
644 LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
645
646 LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
647 infoMap.assignUniqueAlternativeNames();
648 LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
649}
650
651void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
652 LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
653 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
654 auto pattern = getResultPattern(index: i);
655 collectBoundSymbols(tree: pattern, infoMap, /*isSrcPattern=*/false);
656 }
657 LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
658}
659
660const Operator &Pattern::getSourceRootOp() {
661 return getSourcePattern().getDialectOp(mapper: recordOpMap);
662}
663
664Operator &Pattern::getDialectOp(DagNode node) {
665 return node.getDialectOp(mapper: recordOpMap);
666}
667
668std::vector<AppliedConstraint> Pattern::getConstraints() const {
669 auto *listInit = def.getValueAsListInit(FieldName: "constraints");
670 std::vector<AppliedConstraint> ret;
671 ret.reserve(n: listInit->size());
672
673 for (auto *it : *listInit) {
674 auto *dagInit = dyn_cast<DagInit>(Val: it);
675 if (!dagInit)
676 PrintFatalError(Rec: &def, Msg: "all elements in Pattern multi-entity "
677 "constraints should be DAG nodes");
678
679 std::vector<std::string> entities;
680 entities.reserve(n: dagInit->arg_size());
681 for (auto *argName : dagInit->getArgNames()) {
682 if (!argName) {
683 PrintFatalError(
684 Rec: &def,
685 Msg: "operands to additional constraints can only be symbol references");
686 }
687 entities.emplace_back(args: argName->getValue());
688 }
689
690 ret.emplace_back(args: cast<DefInit>(Val: dagInit->getOperator())->getDef(),
691 args: dagInit->getNameStr(), args: std::move(entities));
692 }
693 return ret;
694}
695
696int Pattern::getNumSupplementalPatterns() const {
697 auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns");
698 return results->size();
699}
700
701DagNode Pattern::getSupplementalPattern(unsigned index) const {
702 auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns");
703 return DagNode(cast<DagInit>(Val: results->getElement(Idx: index)));
704}
705
706int Pattern::getBenefit() const {
707 // The initial benefit value is a heuristic with number of ops in the source
708 // pattern.
709 int initBenefit = getSourcePattern().getNumOps();
710 const DagInit *delta = def.getValueAsDag(FieldName: "benefitDelta");
711 if (delta->getNumArgs() != 1 || !isa<IntInit>(Val: delta->getArg(Num: 0))) {
712 PrintFatalError(Rec: &def,
713 Msg: "The 'addBenefit' takes and only takes one integer value");
714 }
715 return initBenefit + dyn_cast<IntInit>(Val: delta->getArg(Num: 0))->getValue();
716}
717
718std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
719 std::vector<std::pair<StringRef, unsigned>> result;
720 result.reserve(n: def.getLoc().size());
721 for (auto loc : def.getLoc()) {
722 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(Loc: loc);
723 assert(buf && "invalid source location");
724 result.emplace_back(
725 args: llvm::SrcMgr.getBufferInfo(i: buf).Buffer->getBufferIdentifier(),
726 args: llvm::SrcMgr.getLineAndColumn(Loc: loc, BufferID: buf).first);
727 }
728 return result;
729}
730
731void Pattern::verifyBind(bool result, StringRef symbolName) {
732 if (!result) {
733 auto err = formatv(Fmt: "symbol '{0}' bound more than once", Vals&: symbolName);
734 PrintFatalError(Rec: &def, Msg: err);
735 }
736}
737
738void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
739 bool isSrcPattern) {
740 auto treeName = tree.getSymbol();
741 auto numTreeArgs = tree.getNumArgs();
742
743 if (tree.isNativeCodeCall()) {
744 if (!treeName.empty()) {
745 if (!isSrcPattern) {
746 LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
747 << treeName << '\n');
748 verifyBind(
749 result: infoMap.bindValues(symbol: treeName, numValues: tree.getNumReturnsOfNativeCode()),
750 symbolName: treeName);
751 } else {
752 PrintFatalError(Rec: &def,
753 Msg: formatv(Fmt: "binding symbol '{0}' to NativecodeCall in "
754 "MatchPattern is not supported",
755 Vals&: treeName));
756 }
757 }
758
759 for (int i = 0; i != numTreeArgs; ++i) {
760 if (auto treeArg = tree.getArgAsNestedDag(index: i)) {
761 // This DAG node argument is a DAG node itself. Go inside recursively.
762 collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern);
763 continue;
764 }
765
766 if (!isSrcPattern)
767 continue;
768
769 // We can only bind symbols to arguments in source pattern. Those
770 // symbols are referenced in result patterns.
771 auto treeArgName = tree.getArgName(index: i);
772
773 // `$_` is a special symbol meaning ignore the current argument.
774 if (!treeArgName.empty() && treeArgName != "_") {
775 DagLeaf leaf = tree.getArgAsLeaf(index: i);
776
777 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
778 if (leaf.isUnspecified()) {
779 // This is case of $c, a Value without any constraints.
780 verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName);
781 } else {
782 auto constraint = leaf.getAsConstraint();
783 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
784 leaf.isConstantAttr() ||
785 constraint.getKind() == Constraint::Kind::CK_Attr;
786
787 if (isAttr) {
788 // This is case of $a, a binding to a certain attribute.
789 verifyBind(result: infoMap.bindAttr(symbol: treeArgName), symbolName: treeArgName);
790 continue;
791 }
792
793 // This is case of $b, a binding to a certain type.
794 verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName);
795 }
796 }
797 }
798
799 return;
800 }
801
802 if (tree.isOperation()) {
803 auto &op = getDialectOp(node: tree);
804 auto numOpArgs = op.getNumArgs();
805 int numEither = 0;
806
807 // We need to exclude the trailing directives and `either` directive groups
808 // two operands of the operation.
809 int numDirectives = 0;
810 for (int i = numTreeArgs - 1; i >= 0; --i) {
811 if (auto dagArg = tree.getArgAsNestedDag(index: i)) {
812 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
813 ++numDirectives;
814 else if (dagArg.isEither())
815 ++numEither;
816 }
817 }
818
819 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
820 auto err =
821 formatv(Fmt: "op '{0}' argument number mismatch: "
822 "{1} in pattern vs. {2} in definition",
823 Vals: op.getOperationName(), Vals: numTreeArgs + numEither, Vals&: numOpArgs);
824 PrintFatalError(Rec: &def, Msg: err);
825 }
826
827 // The name attached to the DAG node's operator is for representing the
828 // results generated from this op. It should be remembered as bound results.
829 if (!treeName.empty()) {
830 LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
831 << '\n');
832 verifyBind(result: infoMap.bindOpResult(symbol: treeName, op), symbolName: treeName);
833 }
834
835 // The operand in `either` DAG should be bound to the operation in the
836 // parent DagNode.
837 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
838 int opArgIdx) {
839 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
840 if (DagNode subTree = tree.getArgAsNestedDag(index: i)) {
841 collectBoundSymbols(tree: subTree, infoMap, isSrcPattern);
842 } else {
843 auto argName = tree.getArgName(index: i);
844 if (!argName.empty() && argName != "_") {
845 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx),
846 symbolName: argName);
847 }
848 }
849 }
850 };
851
852 // The operand in `variadic` DAG should be bound to the operation in the
853 // parent DagNode. The range index must be included as well to distinguish
854 // (potentially) repeating argName within the `variadic` DAG.
855 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
856 int opArgIdx) {
857 auto treeName = tree.getSymbol();
858 if (!treeName.empty()) {
859 // If treeName is specified, bind to the full variadic operand_range.
860 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: treeName, op, argIndex: opArgIdx,
861 variadicSubIndex: std::nullopt),
862 symbolName: treeName);
863 }
864
865 for (int i = 0; i < tree.getNumArgs(); ++i) {
866 if (DagNode subTree = tree.getArgAsNestedDag(index: i)) {
867 collectBoundSymbols(tree: subTree, infoMap, isSrcPattern);
868 } else {
869 auto argName = tree.getArgName(index: i);
870 if (!argName.empty() && argName != "_") {
871 verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx,
872 /*variadicSubIndex=*/i),
873 symbolName: argName);
874 }
875 }
876 }
877 };
878
879 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
880 if (auto treeArg = tree.getArgAsNestedDag(index: i)) {
881 if (treeArg.isEither()) {
882 collectSymbolInEither(tree, treeArg, opArgIdx);
883 // `either` DAG is *flattened*. For example,
884 //
885 // (FooOp (either arg0, arg1), arg2)
886 //
887 // can be viewed as:
888 //
889 // (FooOp arg0, arg1, arg2)
890 ++opArgIdx;
891 } else if (treeArg.isVariadic()) {
892 collectSymbolInVariadic(tree, treeArg, opArgIdx);
893 } else {
894 // This DAG node argument is a DAG node itself. Go inside recursively.
895 collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern);
896 }
897 continue;
898 }
899
900 if (isSrcPattern) {
901 // We can only bind symbols to op arguments in source pattern. Those
902 // symbols are referenced in result patterns.
903 auto treeArgName = tree.getArgName(index: i);
904 // `$_` is a special symbol meaning ignore the current argument.
905 if (!treeArgName.empty() && treeArgName != "_") {
906 LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
907 << treeArgName << '\n');
908 verifyBind(result: infoMap.bindOpArgument(node: tree, symbol: treeArgName, op, argIndex: opArgIdx),
909 symbolName: treeArgName);
910 }
911 }
912 }
913 return;
914 }
915
916 if (!treeName.empty()) {
917 PrintFatalError(
918 Rec: &def, Msg: formatv(Fmt: "binding symbol '{0}' to non-operation/native code call "
919 "unsupported right now",
920 Vals&: treeName));
921 }
922}
923

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/TableGen/Pattern.cpp