1 | //===- Constraint.cpp - Constraint class ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Constraint wrapper to simplify using TableGen Record for constraints. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/TableGen/Constraint.h" |
14 | #include "llvm/TableGen/Record.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::tblgen; |
18 | |
19 | Constraint::Constraint(const llvm::Record *record) |
20 | : Constraint(record, CK_Uncategorized) { |
21 | // Look through OpVariable's to their constraint. |
22 | if (def->isSubClassOf(Name: "OpVariable" )) |
23 | def = def->getValueAsDef(FieldName: "constraint" ); |
24 | |
25 | if (def->isSubClassOf(Name: "TypeConstraint" )) { |
26 | kind = CK_Type; |
27 | } else if (def->isSubClassOf(Name: "AttrConstraint" )) { |
28 | kind = CK_Attr; |
29 | } else if (def->isSubClassOf(Name: "PropConstraint" )) { |
30 | kind = CK_Prop; |
31 | } else if (def->isSubClassOf(Name: "RegionConstraint" )) { |
32 | kind = CK_Region; |
33 | } else if (def->isSubClassOf(Name: "SuccessorConstraint" )) { |
34 | kind = CK_Successor; |
35 | } else if (!def->isSubClassOf(Name: "Constraint" )) { |
36 | llvm::errs() << "Expected a constraint but got: \n" << *def << "\n" ; |
37 | llvm::report_fatal_error(reason: "Abort" ); |
38 | } |
39 | } |
40 | |
41 | Pred Constraint::getPredicate() const { |
42 | auto *val = def->getValue(Name: "predicate" ); |
43 | |
44 | // If no predicate is specified, then return the null predicate (which |
45 | // corresponds to true). |
46 | if (!val) |
47 | return Pred(); |
48 | |
49 | const auto *pred = dyn_cast<llvm::DefInit>(Val: val->getValue()); |
50 | return Pred(pred); |
51 | } |
52 | |
53 | std::string Constraint::getConditionTemplate() const { |
54 | return getPredicate().getCondition(); |
55 | } |
56 | |
57 | StringRef Constraint::getSummary() const { |
58 | if (std::optional<StringRef> summary = |
59 | def->getValueAsOptionalString(FieldName: "summary" )) |
60 | return *summary; |
61 | return def->getName(); |
62 | } |
63 | |
64 | StringRef Constraint::getDescription() const { |
65 | return def->getValueAsOptionalString(FieldName: "description" ).value_or(u: "" ); |
66 | } |
67 | |
68 | StringRef Constraint::getDefName() const { |
69 | if (std::optional<StringRef> baseDefName = getBaseDefName()) |
70 | return *baseDefName; |
71 | return def->getName(); |
72 | } |
73 | |
74 | std::string Constraint::getUniqueDefName() const { |
75 | std::string defName = def->getName().str(); |
76 | |
77 | // Non-anonymous classes already have a unique name from the def. |
78 | if (!def->isAnonymous()) |
79 | return defName; |
80 | |
81 | // Otherwise, this is an anonymous class. In these cases we still use the def |
82 | // name, but we also try attach the name of the base def when present to make |
83 | // the name more obvious. |
84 | if (std::optional<StringRef> baseDefName = getBaseDefName()) |
85 | return (*baseDefName + "(" + defName + ")" ).str(); |
86 | return defName; |
87 | } |
88 | |
89 | std::optional<StringRef> Constraint::getBaseDefName() const { |
90 | // Functor used to check a base def in the case where the current def is |
91 | // anonymous. |
92 | auto checkBaseDefFn = [&](StringRef baseName) -> std::optional<StringRef> { |
93 | if (const auto *defValue = def->getValue(Name: baseName)) { |
94 | if (const auto *defInit = dyn_cast<llvm::DefInit>(Val: defValue->getValue())) |
95 | return Constraint(defInit->getDef(), kind).getDefName(); |
96 | } |
97 | return std::nullopt; |
98 | }; |
99 | |
100 | switch (kind) { |
101 | case CK_Attr: |
102 | if (def->isAnonymous()) |
103 | return checkBaseDefFn("baseAttr" ); |
104 | return std::nullopt; |
105 | case CK_Type: |
106 | if (def->isAnonymous()) |
107 | return checkBaseDefFn("baseType" ); |
108 | return std::nullopt; |
109 | default: |
110 | return std::nullopt; |
111 | } |
112 | } |
113 | |
114 | std::optional<StringRef> Constraint::getCppFunctionName() const { |
115 | std::optional<StringRef> name = |
116 | def->getValueAsOptionalString(FieldName: "cppFunctionName" ); |
117 | if (!name || *name == "" ) |
118 | return std::nullopt; |
119 | return name; |
120 | } |
121 | |
122 | AppliedConstraint::AppliedConstraint(Constraint &&constraint, |
123 | llvm::StringRef self, |
124 | std::vector<std::string> &&entities) |
125 | : constraint(constraint), self(std::string(self)), |
126 | entities(std::move(entities)) {} |
127 | |
128 | Constraint DenseMapInfo<Constraint>::getEmptyKey() { |
129 | return Constraint(RecordDenseMapInfo::getEmptyKey(), |
130 | Constraint::CK_Uncategorized); |
131 | } |
132 | |
133 | Constraint DenseMapInfo<Constraint>::getTombstoneKey() { |
134 | return Constraint(RecordDenseMapInfo::getTombstoneKey(), |
135 | Constraint::CK_Uncategorized); |
136 | } |
137 | |
138 | unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) { |
139 | if (constraint == getEmptyKey()) |
140 | return RecordDenseMapInfo::getHashValue(PtrVal: RecordDenseMapInfo::getEmptyKey()); |
141 | if (constraint == getTombstoneKey()) { |
142 | return RecordDenseMapInfo::getHashValue( |
143 | PtrVal: RecordDenseMapInfo::getTombstoneKey()); |
144 | } |
145 | return llvm::hash_combine(args: constraint.getPredicate(), args: constraint.getSummary()); |
146 | } |
147 | |
148 | bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) { |
149 | if (lhs == rhs) |
150 | return true; |
151 | if (lhs == getEmptyKey() || lhs == getTombstoneKey()) |
152 | return false; |
153 | if (rhs == getEmptyKey() || rhs == getTombstoneKey()) |
154 | return false; |
155 | return lhs.getPredicate() == rhs.getPredicate() && |
156 | lhs.getSummary() == rhs.getSummary(); |
157 | } |
158 | |