1 | //===- Pattern.cpp - Pattern wrapper class --------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Pattern wrapper class to simplify using TableGen Record defining a MLIR |
10 | // Pattern. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include <utility> |
15 | |
16 | #include "mlir/TableGen/Pattern.h" |
17 | #include "llvm/ADT/StringExtras.h" |
18 | #include "llvm/ADT/Twine.h" |
19 | #include "llvm/Support/Debug.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "llvm/TableGen/Error.h" |
22 | #include "llvm/TableGen/Record.h" |
23 | |
24 | #define DEBUG_TYPE "mlir-tblgen-pattern" |
25 | |
26 | using namespace mlir; |
27 | using namespace tblgen; |
28 | |
29 | using llvm::DagInit; |
30 | using llvm::dbgs; |
31 | using llvm::DefInit; |
32 | using llvm::formatv; |
33 | using llvm::IntInit; |
34 | using llvm::Record; |
35 | |
36 | //===----------------------------------------------------------------------===// |
37 | // DagLeaf |
38 | //===----------------------------------------------------------------------===// |
39 | |
40 | bool DagLeaf::isUnspecified() const { |
41 | return isa_and_nonnull<llvm::UnsetInit>(Val: def); |
42 | } |
43 | |
44 | bool DagLeaf::isOperandMatcher() const { |
45 | // Operand matchers specify a type constraint. |
46 | return isSubClassOf(superclass: "TypeConstraint"); |
47 | } |
48 | |
49 | bool DagLeaf::isAttrMatcher() const { |
50 | // Attribute matchers specify an attribute constraint. |
51 | return isSubClassOf(superclass: "AttrConstraint"); |
52 | } |
53 | |
54 | bool DagLeaf::isNativeCodeCall() const { |
55 | return isSubClassOf(superclass: "NativeCodeCall"); |
56 | } |
57 | |
58 | bool DagLeaf::isConstantAttr() const { return isSubClassOf(superclass: "ConstantAttr"); } |
59 | |
60 | bool DagLeaf::isEnumCase() const { return isSubClassOf(superclass: "EnumCase"); } |
61 | |
62 | bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(Val: def); } |
63 | |
64 | Constraint DagLeaf::getAsConstraint() const { |
65 | assert((isOperandMatcher() || isAttrMatcher()) && |
66 | "the DAG leaf must be operand or attribute"); |
67 | return Constraint(cast<DefInit>(Val: def)->getDef()); |
68 | } |
69 | |
70 | ConstantAttr DagLeaf::getAsConstantAttr() const { |
71 | assert(isConstantAttr() && "the DAG leaf must be constant attribute"); |
72 | return ConstantAttr(cast<DefInit>(Val: def)); |
73 | } |
74 | |
75 | EnumCase DagLeaf::getAsEnumCase() const { |
76 | assert(isEnumCase() && "the DAG leaf must be an enum attribute case"); |
77 | return EnumCase(cast<DefInit>(Val: def)); |
78 | } |
79 | |
80 | std::string DagLeaf::getConditionTemplate() const { |
81 | return getAsConstraint().getConditionTemplate(); |
82 | } |
83 | |
84 | StringRef DagLeaf::getNativeCodeTemplate() const { |
85 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); |
86 | return cast<DefInit>(Val: def)->getDef()->getValueAsString(FieldName: "expression"); |
87 | } |
88 | |
89 | int DagLeaf::getNumReturnsOfNativeCode() const { |
90 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); |
91 | return cast<DefInit>(Val: def)->getDef()->getValueAsInt(FieldName: "numReturns"); |
92 | } |
93 | |
94 | std::string DagLeaf::getStringAttr() const { |
95 | assert(isStringAttr() && "the DAG leaf must be string attribute"); |
96 | return def->getAsUnquotedString(); |
97 | } |
98 | bool DagLeaf::isSubClassOf(StringRef superclass) const { |
99 | if (auto *defInit = dyn_cast_or_null<DefInit>(Val: def)) |
100 | return defInit->getDef()->isSubClassOf(Name: superclass); |
101 | return false; |
102 | } |
103 | |
104 | void DagLeaf::print(raw_ostream &os) const { |
105 | if (def) |
106 | def->print(OS&: os); |
107 | } |
108 | |
109 | //===----------------------------------------------------------------------===// |
110 | // DagNode |
111 | //===----------------------------------------------------------------------===// |
112 | |
113 | bool DagNode::isNativeCodeCall() const { |
114 | if (auto *defInit = dyn_cast_or_null<DefInit>(Val: node->getOperator())) |
115 | return defInit->getDef()->isSubClassOf(Name: "NativeCodeCall"); |
116 | return false; |
117 | } |
118 | |
119 | bool DagNode::isOperation() const { |
120 | return !isNativeCodeCall() && !isReplaceWithValue() && |
121 | !isLocationDirective() && !isReturnTypeDirective() && !isEither() && |
122 | !isVariadic(); |
123 | } |
124 | |
125 | StringRef DagNode::getNativeCodeTemplate() const { |
126 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); |
127 | return cast<DefInit>(Val: node->getOperator()) |
128 | ->getDef() |
129 | ->getValueAsString(FieldName: "expression"); |
130 | } |
131 | |
132 | int DagNode::getNumReturnsOfNativeCode() const { |
133 | assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); |
134 | return cast<DefInit>(Val: node->getOperator()) |
135 | ->getDef() |
136 | ->getValueAsInt(FieldName: "numReturns"); |
137 | } |
138 | |
139 | StringRef DagNode::getSymbol() const { return node->getNameStr(); } |
140 | |
141 | Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { |
142 | const Record *opDef = cast<DefInit>(Val: node->getOperator())->getDef(); |
143 | auto [it, inserted] = mapper->try_emplace(Key: opDef); |
144 | if (inserted) |
145 | it->second = std::make_unique<Operator>(args&: opDef); |
146 | return *it->second; |
147 | } |
148 | |
149 | int DagNode::getNumOps() const { |
150 | // We want to get number of operations recursively involved in the DAG tree. |
151 | // All other directives should be excluded. |
152 | int count = isOperation() ? 1 : 0; |
153 | for (int i = 0, e = getNumArgs(); i != e; ++i) { |
154 | if (auto child = getArgAsNestedDag(index: i)) |
155 | count += child.getNumOps(); |
156 | } |
157 | return count; |
158 | } |
159 | |
160 | int DagNode::getNumArgs() const { return node->getNumArgs(); } |
161 | |
162 | bool DagNode::isNestedDagArg(unsigned index) const { |
163 | return isa<DagInit>(Val: node->getArg(Num: index)); |
164 | } |
165 | |
166 | DagNode DagNode::getArgAsNestedDag(unsigned index) const { |
167 | return DagNode(dyn_cast_or_null<DagInit>(Val: node->getArg(Num: index))); |
168 | } |
169 | |
170 | DagLeaf DagNode::getArgAsLeaf(unsigned index) const { |
171 | assert(!isNestedDagArg(index)); |
172 | return DagLeaf(node->getArg(Num: index)); |
173 | } |
174 | |
175 | StringRef DagNode::getArgName(unsigned index) const { |
176 | return node->getArgNameStr(Num: index); |
177 | } |
178 | |
179 | bool DagNode::isReplaceWithValue() const { |
180 | auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef(); |
181 | return dagOpDef->getName() == "replaceWithValue"; |
182 | } |
183 | |
184 | bool DagNode::isLocationDirective() const { |
185 | auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef(); |
186 | return dagOpDef->getName() == "location"; |
187 | } |
188 | |
189 | bool DagNode::isReturnTypeDirective() const { |
190 | auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef(); |
191 | return dagOpDef->getName() == "returnType"; |
192 | } |
193 | |
194 | bool DagNode::isEither() const { |
195 | auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef(); |
196 | return dagOpDef->getName() == "either"; |
197 | } |
198 | |
199 | bool DagNode::isVariadic() const { |
200 | auto *dagOpDef = cast<DefInit>(Val: node->getOperator())->getDef(); |
201 | return dagOpDef->getName() == "variadic"; |
202 | } |
203 | |
204 | void DagNode::print(raw_ostream &os) const { |
205 | if (node) |
206 | node->print(OS&: os); |
207 | } |
208 | |
209 | //===----------------------------------------------------------------------===// |
210 | // SymbolInfoMap |
211 | //===----------------------------------------------------------------------===// |
212 | |
213 | StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { |
214 | int idx = -1; |
215 | auto [name, indexStr] = symbol.rsplit(Separator: "__"); |
216 | |
217 | if (indexStr.consumeInteger(Radix: 10, Result&: idx)) { |
218 | // The second part is not an index; we return the whole symbol as-is. |
219 | return symbol; |
220 | } |
221 | if (index) { |
222 | *index = idx; |
223 | } |
224 | return name; |
225 | } |
226 | |
227 | SymbolInfoMap::SymbolInfo::SymbolInfo( |
228 | const Operator *op, SymbolInfo::Kind kind, |
229 | std::optional<DagAndConstant> dagAndConstant) |
230 | : op(op), kind(kind), dagAndConstant(dagAndConstant) {} |
231 | |
232 | int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { |
233 | switch (kind) { |
234 | case Kind::Attr: |
235 | case Kind::Operand: |
236 | case Kind::Value: |
237 | return 1; |
238 | case Kind::Result: |
239 | return op->getNumResults(); |
240 | case Kind::MultipleValues: |
241 | return getSize(); |
242 | } |
243 | llvm_unreachable("unknown kind"); |
244 | } |
245 | |
246 | std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { |
247 | return alternativeName ? *alternativeName : name.str(); |
248 | } |
249 | |
250 | std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { |
251 | LLVM_DEBUG(dbgs() << "getVarTypeStr for '"<< name << "': "); |
252 | switch (kind) { |
253 | case Kind::Attr: { |
254 | if (op) |
255 | return cast<NamedAttribute *>(Val: op->getArg(index: getArgIndex())) |
256 | ->attr.getStorageType() |
257 | .str(); |
258 | // TODO(suderman): Use a more exact type when available. |
259 | return "::mlir::Attribute"; |
260 | } |
261 | case Kind::Operand: { |
262 | // Use operand range for captured operands (to support potential variadic |
263 | // operands). |
264 | return "::mlir::Operation::operand_range"; |
265 | } |
266 | case Kind::Value: { |
267 | return "::mlir::Value"; |
268 | } |
269 | case Kind::MultipleValues: { |
270 | return "::mlir::ValueRange"; |
271 | } |
272 | case Kind::Result: { |
273 | // Use the op itself for captured results. |
274 | return op->getQualCppClassName(); |
275 | } |
276 | } |
277 | llvm_unreachable("unknown kind"); |
278 | } |
279 | |
280 | std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { |
281 | LLVM_DEBUG(dbgs() << "getVarDecl for '"<< name << "': "); |
282 | std::string varInit = kind == Kind::Operand ? "(op0->getOperands())": ""; |
283 | return std::string( |
284 | formatv(Fmt: "{0} {1}{2};\n", Vals: getVarTypeStr(name), Vals: getVarName(name), Vals&: varInit)); |
285 | } |
286 | |
287 | std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { |
288 | LLVM_DEBUG(dbgs() << "getArgDecl for '"<< name << "': "); |
289 | return std::string( |
290 | formatv(Fmt: "{0} &{1}", Vals: getVarTypeStr(name), Vals: getVarName(name))); |
291 | } |
292 | |
293 | std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( |
294 | StringRef name, int index, const char *fmt, const char *separator) const { |
295 | LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '"<< name << "': "); |
296 | switch (kind) { |
297 | case Kind::Attr: { |
298 | assert(index < 0); |
299 | auto repl = formatv(Fmt: fmt, Vals&: name); |
300 | LLVM_DEBUG(dbgs() << repl << " (Attr)\n"); |
301 | return std::string(repl); |
302 | } |
303 | case Kind::Operand: { |
304 | assert(index < 0); |
305 | auto *operand = cast<NamedTypeConstraint *>(Val: op->getArg(index: getArgIndex())); |
306 | if (operand->isOptional()) { |
307 | auto repl = formatv( |
308 | Fmt: fmt, Vals: formatv(Fmt: "({0}.empty() ? ::mlir::Value() : *{0}.begin())", Vals&: name)); |
309 | LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n"); |
310 | return std::string(repl); |
311 | } |
312 | // If this operand is variadic and this SymbolInfo doesn't have a range |
313 | // index, then return the full variadic operand_range. Otherwise, return |
314 | // the value itself. |
315 | if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) { |
316 | auto repl = formatv(Fmt: fmt, Vals&: name); |
317 | LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n"); |
318 | return std::string(repl); |
319 | } |
320 | auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "(*{0}.begin())", Vals&: name)); |
321 | LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n"); |
322 | return std::string(repl); |
323 | } |
324 | case Kind::Result: { |
325 | // If `index` is greater than zero, then we are referencing a specific |
326 | // result of a multi-result op. The result can still be variadic. |
327 | if (index >= 0) { |
328 | std::string v = |
329 | std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index)); |
330 | if (!op->getResult(index).isVariadic()) |
331 | v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v)); |
332 | auto repl = formatv(Fmt: fmt, Vals&: v); |
333 | LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); |
334 | return std::string(repl); |
335 | } |
336 | |
337 | // If this op has no result at all but still we bind a symbol to it, it |
338 | // means we want to capture the op itself. |
339 | if (op->getNumResults() == 0) { |
340 | LLVM_DEBUG(dbgs() << name << " (Op)\n"); |
341 | return formatv(Fmt: fmt, Vals&: name); |
342 | } |
343 | |
344 | // We are referencing all results of the multi-result op. A specific result |
345 | // can either be a value or a range. Then join them with `separator`. |
346 | SmallVector<std::string, 4> values; |
347 | values.reserve(N: op->getNumResults()); |
348 | |
349 | for (int i = 0, e = op->getNumResults(); i < e; ++i) { |
350 | std::string v = std::string(formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i)); |
351 | if (!op->getResult(index: i).isVariadic()) { |
352 | v = std::string(formatv(Fmt: "(*{0}.begin())", Vals&: v)); |
353 | } |
354 | values.push_back(Elt: std::string(formatv(Fmt: fmt, Vals&: v))); |
355 | } |
356 | auto repl = llvm::join(R&: values, Separator: separator); |
357 | LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n"); |
358 | return repl; |
359 | } |
360 | case Kind::Value: { |
361 | assert(index < 0); |
362 | assert(op == nullptr); |
363 | auto repl = formatv(Fmt: fmt, Vals&: name); |
364 | LLVM_DEBUG(dbgs() << repl << " (Value)\n"); |
365 | return std::string(repl); |
366 | } |
367 | case Kind::MultipleValues: { |
368 | assert(op == nullptr); |
369 | assert(index < getSize()); |
370 | if (index >= 0) { |
371 | std::string repl = |
372 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index))); |
373 | LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); |
374 | return repl; |
375 | } |
376 | // If it doesn't specify certain element, unpack them all. |
377 | auto repl = |
378 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name))); |
379 | LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); |
380 | return std::string(repl); |
381 | } |
382 | } |
383 | llvm_unreachable("unknown kind"); |
384 | } |
385 | |
386 | std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( |
387 | StringRef name, int index, const char *fmt, const char *separator) const { |
388 | LLVM_DEBUG(dbgs() << "getAllRangeUse for '"<< name << "': "); |
389 | switch (kind) { |
390 | case Kind::Attr: |
391 | case Kind::Operand: { |
392 | assert(index < 0 && "only allowed for symbol bound to result"); |
393 | auto repl = formatv(Fmt: fmt, Vals&: name); |
394 | LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n"); |
395 | return std::string(repl); |
396 | } |
397 | case Kind::Result: { |
398 | if (index >= 0) { |
399 | auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: index)); |
400 | LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); |
401 | return std::string(repl); |
402 | } |
403 | |
404 | // We are referencing all results of the multi-result op. Each result should |
405 | // have a value range, and then join them with `separator`. |
406 | SmallVector<std::string, 4> values; |
407 | values.reserve(N: op->getNumResults()); |
408 | |
409 | for (int i = 0, e = op->getNumResults(); i < e; ++i) { |
410 | values.push_back(Elt: std::string( |
411 | formatv(Fmt: fmt, Vals: formatv(Fmt: "{0}.getODSResults({1})", Vals&: name, Vals&: i)))); |
412 | } |
413 | auto repl = llvm::join(R&: values, Separator: separator); |
414 | LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n"); |
415 | return repl; |
416 | } |
417 | case Kind::Value: { |
418 | assert(index < 0 && "only allowed for symbol bound to result"); |
419 | assert(op == nullptr); |
420 | auto repl = formatv(Fmt: fmt, Vals: formatv(Fmt: "{{{0}}", Vals&: name)); |
421 | LLVM_DEBUG(dbgs() << repl << " (Value)\n"); |
422 | return std::string(repl); |
423 | } |
424 | case Kind::MultipleValues: { |
425 | assert(op == nullptr); |
426 | assert(index < getSize()); |
427 | if (index >= 0) { |
428 | std::string repl = |
429 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}[{1}]", Vals&: name, Vals&: index))); |
430 | LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); |
431 | return repl; |
432 | } |
433 | auto repl = |
434 | formatv(Fmt: fmt, Vals: std::string(formatv(Fmt: "{0}.begin(), {0}.end()", Vals&: name))); |
435 | LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); |
436 | return std::string(repl); |
437 | } |
438 | } |
439 | llvm_unreachable("unknown kind"); |
440 | } |
441 | |
442 | bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, |
443 | const Operator &op, int argIndex, |
444 | std::optional<int> variadicSubIndex) { |
445 | StringRef name = getValuePackName(symbol); |
446 | if (name != symbol) { |
447 | auto error = formatv( |
448 | Fmt: "symbol '{0}' with trailing index cannot bind to op argument", Vals&: symbol); |
449 | PrintFatalError(ErrorLoc: loc, Msg: error); |
450 | } |
451 | |
452 | auto symInfo = |
453 | isa<NamedAttribute *>(Val: op.getArg(index: argIndex)) |
454 | ? SymbolInfo::getAttr(op: &op, index: argIndex) |
455 | : SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex); |
456 | |
457 | std::string key = symbol.str(); |
458 | if (symbolInfoMap.count(x: key)) { |
459 | // Only non unique name for the operand is supported. |
460 | if (symInfo.kind != SymbolInfo::Kind::Operand) { |
461 | return false; |
462 | } |
463 | |
464 | // Cannot add new operand if there is already non operand with the same |
465 | // name. |
466 | if (symbolInfoMap.find(x: key)->second.kind != SymbolInfo::Kind::Operand) { |
467 | return false; |
468 | } |
469 | } |
470 | |
471 | symbolInfoMap.emplace(args&: key, args&: symInfo); |
472 | return true; |
473 | } |
474 | |
475 | bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { |
476 | std::string name = getValuePackName(symbol).str(); |
477 | auto inserted = symbolInfoMap.emplace(args&: name, args: SymbolInfo::getResult(op: &op)); |
478 | |
479 | return symbolInfoMap.count(x: inserted->first) == 1; |
480 | } |
481 | |
482 | bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { |
483 | std::string name = getValuePackName(symbol).str(); |
484 | if (numValues > 1) |
485 | return bindMultipleValues(symbol: name, numValues); |
486 | return bindValue(symbol: name); |
487 | } |
488 | |
489 | bool SymbolInfoMap::bindValue(StringRef symbol) { |
490 | auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getValue()); |
491 | return symbolInfoMap.count(x: inserted->first) == 1; |
492 | } |
493 | |
494 | bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { |
495 | std::string name = getValuePackName(symbol).str(); |
496 | auto inserted = |
497 | symbolInfoMap.emplace(args&: name, args: SymbolInfo::getMultipleValues(numValues)); |
498 | return symbolInfoMap.count(x: inserted->first) == 1; |
499 | } |
500 | |
501 | bool SymbolInfoMap::bindAttr(StringRef symbol) { |
502 | auto inserted = symbolInfoMap.emplace(args: symbol.str(), args: SymbolInfo::getAttr()); |
503 | return symbolInfoMap.count(x: inserted->first) == 1; |
504 | } |
505 | |
506 | bool SymbolInfoMap::contains(StringRef symbol) const { |
507 | return find(key: symbol) != symbolInfoMap.end(); |
508 | } |
509 | |
510 | SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { |
511 | std::string name = getValuePackName(symbol: key).str(); |
512 | |
513 | return symbolInfoMap.find(x: name); |
514 | } |
515 | |
516 | SymbolInfoMap::const_iterator |
517 | SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, |
518 | int argIndex, |
519 | std::optional<int> variadicSubIndex) const { |
520 | return findBoundSymbol( |
521 | key, symbolInfo: SymbolInfo::getOperand(node, op: &op, operandIndex: argIndex, variadicSubIndex)); |
522 | } |
523 | |
524 | SymbolInfoMap::const_iterator |
525 | SymbolInfoMap::findBoundSymbol(StringRef key, |
526 | const SymbolInfo &symbolInfo) const { |
527 | std::string name = getValuePackName(symbol: key).str(); |
528 | auto range = symbolInfoMap.equal_range(x: name); |
529 | |
530 | for (auto it = range.first; it != range.second; ++it) |
531 | if (it->second.dagAndConstant == symbolInfo.dagAndConstant) |
532 | return it; |
533 | |
534 | return symbolInfoMap.end(); |
535 | } |
536 | |
537 | std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> |
538 | SymbolInfoMap::getRangeOfEqualElements(StringRef key) { |
539 | std::string name = getValuePackName(symbol: key).str(); |
540 | |
541 | return symbolInfoMap.equal_range(x: name); |
542 | } |
543 | |
544 | int SymbolInfoMap::count(StringRef key) const { |
545 | std::string name = getValuePackName(symbol: key).str(); |
546 | return symbolInfoMap.count(x: name); |
547 | } |
548 | |
549 | int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { |
550 | StringRef name = getValuePackName(symbol); |
551 | if (name != symbol) { |
552 | // If there is a trailing index inside symbol, it references just one |
553 | // static value. |
554 | return 1; |
555 | } |
556 | // Otherwise, find how many it represents by querying the symbol's info. |
557 | return find(key: name)->second.getStaticValueCount(); |
558 | } |
559 | |
560 | std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, |
561 | const char *fmt, |
562 | const char *separator) const { |
563 | int index = -1; |
564 | StringRef name = getValuePackName(symbol, index: &index); |
565 | |
566 | auto it = symbolInfoMap.find(x: name.str()); |
567 | if (it == symbolInfoMap.end()) { |
568 | auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol); |
569 | PrintFatalError(ErrorLoc: loc, Msg: error); |
570 | } |
571 | |
572 | return it->second.getValueAndRangeUse(name, index, fmt, separator); |
573 | } |
574 | |
575 | std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, |
576 | const char *separator) const { |
577 | int index = -1; |
578 | StringRef name = getValuePackName(symbol, index: &index); |
579 | |
580 | auto it = symbolInfoMap.find(x: name.str()); |
581 | if (it == symbolInfoMap.end()) { |
582 | auto error = formatv(Fmt: "referencing unbound symbol '{0}'", Vals&: symbol); |
583 | PrintFatalError(ErrorLoc: loc, Msg: error); |
584 | } |
585 | |
586 | return it->second.getAllRangeUse(name, index, fmt, separator); |
587 | } |
588 | |
589 | void SymbolInfoMap::assignUniqueAlternativeNames() { |
590 | llvm::StringSet<> usedNames; |
591 | |
592 | for (auto symbolInfoIt = symbolInfoMap.begin(); |
593 | symbolInfoIt != symbolInfoMap.end();) { |
594 | auto range = symbolInfoMap.equal_range(x: symbolInfoIt->first); |
595 | auto startRange = range.first; |
596 | auto endRange = range.second; |
597 | |
598 | auto operandName = symbolInfoIt->first; |
599 | int startSearchIndex = 0; |
600 | for (++startRange; startRange != endRange; ++startRange) { |
601 | // Current operand name is not unique, find a unique one |
602 | // and set the alternative name. |
603 | for (int i = startSearchIndex;; ++i) { |
604 | std::string alternativeName = operandName + std::to_string(val: i); |
605 | if (!usedNames.contains(key: alternativeName) && |
606 | symbolInfoMap.count(x: alternativeName) == 0) { |
607 | usedNames.insert(key: alternativeName); |
608 | startRange->second.alternativeName = alternativeName; |
609 | startSearchIndex = i + 1; |
610 | |
611 | break; |
612 | } |
613 | } |
614 | } |
615 | |
616 | symbolInfoIt = endRange; |
617 | } |
618 | } |
619 | |
620 | //===----------------------------------------------------------------------===// |
621 | // Pattern |
622 | //===----------------------------------------------------------------------===// |
623 | |
624 | Pattern::Pattern(const Record *def, RecordOperatorMap *mapper) |
625 | : def(*def), recordOpMap(mapper) {} |
626 | |
627 | DagNode Pattern::getSourcePattern() const { |
628 | return DagNode(def.getValueAsDag(FieldName: "sourcePattern")); |
629 | } |
630 | |
631 | int Pattern::getNumResultPatterns() const { |
632 | auto *results = def.getValueAsListInit(FieldName: "resultPatterns"); |
633 | return results->size(); |
634 | } |
635 | |
636 | DagNode Pattern::getResultPattern(unsigned index) const { |
637 | auto *results = def.getValueAsListInit(FieldName: "resultPatterns"); |
638 | return DagNode(cast<DagInit>(Val: results->getElement(Idx: index))); |
639 | } |
640 | |
641 | void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { |
642 | LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n"); |
643 | collectBoundSymbols(tree: getSourcePattern(), infoMap, /*isSrcPattern=*/true); |
644 | LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n"); |
645 | |
646 | LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n"); |
647 | infoMap.assignUniqueAlternativeNames(); |
648 | LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n"); |
649 | } |
650 | |
651 | void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { |
652 | LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n"); |
653 | for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { |
654 | auto pattern = getResultPattern(index: i); |
655 | collectBoundSymbols(tree: pattern, infoMap, /*isSrcPattern=*/false); |
656 | } |
657 | LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n"); |
658 | } |
659 | |
660 | const Operator &Pattern::getSourceRootOp() { |
661 | return getSourcePattern().getDialectOp(mapper: recordOpMap); |
662 | } |
663 | |
664 | Operator &Pattern::getDialectOp(DagNode node) { |
665 | return node.getDialectOp(mapper: recordOpMap); |
666 | } |
667 | |
668 | std::vector<AppliedConstraint> Pattern::getConstraints() const { |
669 | auto *listInit = def.getValueAsListInit(FieldName: "constraints"); |
670 | std::vector<AppliedConstraint> ret; |
671 | ret.reserve(n: listInit->size()); |
672 | |
673 | for (auto *it : *listInit) { |
674 | auto *dagInit = dyn_cast<DagInit>(Val: it); |
675 | if (!dagInit) |
676 | PrintFatalError(Rec: &def, Msg: "all elements in Pattern multi-entity " |
677 | "constraints should be DAG nodes"); |
678 | |
679 | std::vector<std::string> entities; |
680 | entities.reserve(n: dagInit->arg_size()); |
681 | for (auto *argName : dagInit->getArgNames()) { |
682 | if (!argName) { |
683 | PrintFatalError( |
684 | Rec: &def, |
685 | Msg: "operands to additional constraints can only be symbol references"); |
686 | } |
687 | entities.emplace_back(args: argName->getValue()); |
688 | } |
689 | |
690 | ret.emplace_back(args: cast<DefInit>(Val: dagInit->getOperator())->getDef(), |
691 | args: dagInit->getNameStr(), args: std::move(entities)); |
692 | } |
693 | return ret; |
694 | } |
695 | |
696 | int Pattern::getNumSupplementalPatterns() const { |
697 | auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns"); |
698 | return results->size(); |
699 | } |
700 | |
701 | DagNode Pattern::getSupplementalPattern(unsigned index) const { |
702 | auto *results = def.getValueAsListInit(FieldName: "supplementalPatterns"); |
703 | return DagNode(cast<DagInit>(Val: results->getElement(Idx: index))); |
704 | } |
705 | |
706 | int Pattern::getBenefit() const { |
707 | // The initial benefit value is a heuristic with number of ops in the source |
708 | // pattern. |
709 | int initBenefit = getSourcePattern().getNumOps(); |
710 | const DagInit *delta = def.getValueAsDag(FieldName: "benefitDelta"); |
711 | if (delta->getNumArgs() != 1 || !isa<IntInit>(Val: delta->getArg(Num: 0))) { |
712 | PrintFatalError(Rec: &def, |
713 | Msg: "The 'addBenefit' takes and only takes one integer value"); |
714 | } |
715 | return initBenefit + dyn_cast<IntInit>(Val: delta->getArg(Num: 0))->getValue(); |
716 | } |
717 | |
718 | std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { |
719 | std::vector<std::pair<StringRef, unsigned>> result; |
720 | result.reserve(n: def.getLoc().size()); |
721 | for (auto loc : def.getLoc()) { |
722 | unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(Loc: loc); |
723 | assert(buf && "invalid source location"); |
724 | result.emplace_back( |
725 | args: llvm::SrcMgr.getBufferInfo(i: buf).Buffer->getBufferIdentifier(), |
726 | args: llvm::SrcMgr.getLineAndColumn(Loc: loc, BufferID: buf).first); |
727 | } |
728 | return result; |
729 | } |
730 | |
731 | void Pattern::verifyBind(bool result, StringRef symbolName) { |
732 | if (!result) { |
733 | auto err = formatv(Fmt: "symbol '{0}' bound more than once", Vals&: symbolName); |
734 | PrintFatalError(Rec: &def, Msg: err); |
735 | } |
736 | } |
737 | |
738 | void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, |
739 | bool isSrcPattern) { |
740 | auto treeName = tree.getSymbol(); |
741 | auto numTreeArgs = tree.getNumArgs(); |
742 | |
743 | if (tree.isNativeCodeCall()) { |
744 | if (!treeName.empty()) { |
745 | if (!isSrcPattern) { |
746 | LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: " |
747 | << treeName << '\n'); |
748 | verifyBind( |
749 | result: infoMap.bindValues(symbol: treeName, numValues: tree.getNumReturnsOfNativeCode()), |
750 | symbolName: treeName); |
751 | } else { |
752 | PrintFatalError(Rec: &def, |
753 | Msg: formatv(Fmt: "binding symbol '{0}' to NativecodeCall in " |
754 | "MatchPattern is not supported", |
755 | Vals&: treeName)); |
756 | } |
757 | } |
758 | |
759 | for (int i = 0; i != numTreeArgs; ++i) { |
760 | if (auto treeArg = tree.getArgAsNestedDag(index: i)) { |
761 | // This DAG node argument is a DAG node itself. Go inside recursively. |
762 | collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern); |
763 | continue; |
764 | } |
765 | |
766 | if (!isSrcPattern) |
767 | continue; |
768 | |
769 | // We can only bind symbols to arguments in source pattern. Those |
770 | // symbols are referenced in result patterns. |
771 | auto treeArgName = tree.getArgName(index: i); |
772 | |
773 | // `$_` is a special symbol meaning ignore the current argument. |
774 | if (!treeArgName.empty() && treeArgName != "_") { |
775 | DagLeaf leaf = tree.getArgAsLeaf(index: i); |
776 | |
777 | // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), |
778 | if (leaf.isUnspecified()) { |
779 | // This is case of $c, a Value without any constraints. |
780 | verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName); |
781 | } else { |
782 | auto constraint = leaf.getAsConstraint(); |
783 | bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() || |
784 | leaf.isConstantAttr() || |
785 | constraint.getKind() == Constraint::Kind::CK_Attr; |
786 | |
787 | if (isAttr) { |
788 | // This is case of $a, a binding to a certain attribute. |
789 | verifyBind(result: infoMap.bindAttr(symbol: treeArgName), symbolName: treeArgName); |
790 | continue; |
791 | } |
792 | |
793 | // This is case of $b, a binding to a certain type. |
794 | verifyBind(result: infoMap.bindValue(symbol: treeArgName), symbolName: treeArgName); |
795 | } |
796 | } |
797 | } |
798 | |
799 | return; |
800 | } |
801 | |
802 | if (tree.isOperation()) { |
803 | auto &op = getDialectOp(node: tree); |
804 | auto numOpArgs = op.getNumArgs(); |
805 | int numEither = 0; |
806 | |
807 | // Location/return-type directives are not op arguments, so exclude them from |
808 | // the count; each `either` directive groups two operands, so it adds one. |
809 | int numDirectives = 0; |
810 | for (int i = numTreeArgs - 1; i >= 0; --i) { |
811 | if (auto dagArg = tree.getArgAsNestedDag(index: i)) { |
812 | if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) |
813 | ++numDirectives; |
814 | else if (dagArg.isEither()) |
815 | ++numEither; |
816 | } |
817 | } |
818 | |
819 | if (numOpArgs != numTreeArgs - numDirectives + numEither) { |
820 | auto err = |
821 | formatv(Fmt: "op '{0}' argument number mismatch: " |
822 | "{1} in pattern vs. {2} in definition", |
823 | Vals: op.getOperationName(), Vals: numTreeArgs + numEither, Vals&: numOpArgs); |
824 | PrintFatalError(Rec: &def, Msg: err); |
825 | } |
826 | |
827 | // The name attached to the DAG node's operator is for representing the |
828 | // results generated from this op. It should be remembered as bound results. |
829 | if (!treeName.empty()) { |
830 | LLVM_DEBUG(dbgs() << "found symbol bound to op result: "<< treeName |
831 | << '\n'); |
832 | verifyBind(result: infoMap.bindOpResult(symbol: treeName, op), symbolName: treeName); |
833 | } |
834 | |
835 | // The operand in `either` DAG should be bound to the operation in the |
836 | // parent DagNode. |
837 | auto collectSymbolInEither = [&](DagNode parent, DagNode tree, |
838 | int opArgIdx) { |
839 | for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { |
840 | if (DagNode subTree = tree.getArgAsNestedDag(index: i)) { |
841 | collectBoundSymbols(tree: subTree, infoMap, isSrcPattern); |
842 | } else { |
843 | auto argName = tree.getArgName(index: i); |
844 | if (!argName.empty() && argName != "_") { |
845 | verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx), |
846 | symbolName: argName); |
847 | } |
848 | } |
849 | } |
850 | }; |
851 | |
852 | // The operand in `variadic` DAG should be bound to the operation in the |
853 | // parent DagNode. The range index must be included as well to distinguish |
854 | // (potentially) repeating argName within the `variadic` DAG. |
855 | auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree, |
856 | int opArgIdx) { |
857 | auto treeName = tree.getSymbol(); |
858 | if (!treeName.empty()) { |
859 | // If treeName is specified, bind to the full variadic operand_range. |
860 | verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: treeName, op, argIndex: opArgIdx, |
861 | variadicSubIndex: std::nullopt), |
862 | symbolName: treeName); |
863 | } |
864 | |
865 | for (int i = 0; i < tree.getNumArgs(); ++i) { |
866 | if (DagNode subTree = tree.getArgAsNestedDag(index: i)) { |
867 | collectBoundSymbols(tree: subTree, infoMap, isSrcPattern); |
868 | } else { |
869 | auto argName = tree.getArgName(index: i); |
870 | if (!argName.empty() && argName != "_") { |
871 | verifyBind(result: infoMap.bindOpArgument(node: parent, symbol: argName, op, argIndex: opArgIdx, |
872 | /*variadicSubIndex=*/i), |
873 | symbolName: argName); |
874 | } |
875 | } |
876 | } |
877 | }; |
878 | |
879 | for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { |
880 | if (auto treeArg = tree.getArgAsNestedDag(index: i)) { |
881 | if (treeArg.isEither()) { |
882 | collectSymbolInEither(tree, treeArg, opArgIdx); |
883 | // `either` DAG is *flattened*. For example, |
884 | // |
885 | // (FooOp (either arg0, arg1), arg2) |
886 | // |
887 | // can be viewed as: |
888 | // |
889 | // (FooOp arg0, arg1, arg2) |
890 | ++opArgIdx; |
891 | } else if (treeArg.isVariadic()) { |
892 | collectSymbolInVariadic(tree, treeArg, opArgIdx); |
893 | } else { |
894 | // This DAG node argument is a DAG node itself. Go inside recursively. |
895 | collectBoundSymbols(tree: treeArg, infoMap, isSrcPattern); |
896 | } |
897 | continue; |
898 | } |
899 | |
900 | if (isSrcPattern) { |
901 | // We can only bind symbols to op arguments in source pattern. Those |
902 | // symbols are referenced in result patterns. |
903 | auto treeArgName = tree.getArgName(index: i); |
904 | // `$_` is a special symbol meaning ignore the current argument. |
905 | if (!treeArgName.empty() && treeArgName != "_") { |
906 | LLVM_DEBUG(dbgs() << "found symbol bound to op argument: " |
907 | << treeArgName << '\n'); |
908 | verifyBind(result: infoMap.bindOpArgument(node: tree, symbol: treeArgName, op, argIndex: opArgIdx), |
909 | symbolName: treeArgName); |
910 | } |
911 | } |
912 | } |
913 | return; |
914 | } |
915 | |
916 | if (!treeName.empty()) { |
917 | PrintFatalError( |
918 | Rec: &def, Msg: formatv(Fmt: "binding symbol '{0}' to non-operation/native code call " |
919 | "unsupported right now", |
920 | Vals&: treeName)); |
921 | } |
922 | } |
923 |
Definitions
- isUnspecified
- isOperandMatcher
- isAttrMatcher
- isNativeCodeCall
- isConstantAttr
- isEnumCase
- isStringAttr
- getAsConstraint
- getAsConstantAttr
- getAsEnumCase
- getConditionTemplate
- getNativeCodeTemplate
- getNumReturnsOfNativeCode
- getStringAttr
- isSubClassOf
- isNativeCodeCall
- isOperation
- getNativeCodeTemplate
- getNumReturnsOfNativeCode
- getSymbol
- getDialectOp
- getNumOps
- getNumArgs
- isNestedDagArg
- getArgAsNestedDag
- getArgAsLeaf
- getArgName
- isReplaceWithValue
- isLocationDirective
- isReturnTypeDirective
- isEither
- isVariadic
- getValuePackName
- SymbolInfo
- getStaticValueCount
- getVarName
- getVarTypeStr
- getVarDecl
- getArgDecl
- getValueAndRangeUse
- getAllRangeUse
- bindOpArgument
- bindOpResult
- bindValues
- bindValue
- bindMultipleValues
- bindAttr
- contains
- find
- findBoundSymbol
- findBoundSymbol
- getRangeOfEqualElements
- count
- getStaticValueCount
- getValueAndRangeUse
- getAllRangeUse
- assignUniqueAlternativeNames
- Pattern
- getSourcePattern
- getNumResultPatterns
- getResultPattern
- collectSourcePatternBoundSymbols
- collectResultPatternBoundSymbols
- getSourceRootOp
- getDialectOp
- getConstraints
- getNumSupplementalPatterns
- getSupplementalPattern
- getBenefit
- getLocation
- verifyBind
Improve your Profiling and Debugging skills
Find out more