1 | //===- Attribute.cpp - Attribute wrapper class ----------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Attribute wrapper to simplify using TableGen Record defining a MLIR |
10 | // Attribute. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/TableGen/Format.h" |
15 | #include "mlir/TableGen/Operator.h" |
16 | #include "llvm/TableGen/Record.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::tblgen; |
20 | |
21 | using llvm::DefInit; |
22 | using llvm::Init; |
23 | using llvm::Record; |
24 | using llvm::StringInit; |
25 | |
26 | // Returns the initializer's value as string if the given TableGen initializer |
27 | // is a code or string initializer. Returns the empty StringRef otherwise. |
28 | static StringRef getValueAsString(const Init *init) { |
29 | if (const auto *str = dyn_cast<StringInit>(Val: init)) |
30 | return str->getValue().trim(); |
31 | return {}; |
32 | } |
33 | |
34 | bool AttrConstraint::isSubClassOf(StringRef className) const { |
35 | return def->isSubClassOf(Name: className); |
36 | } |
37 | |
38 | Attribute::Attribute(const Record *record) : AttrConstraint(record) { |
39 | assert(record->isSubClassOf("Attr") && |
40 | "must be subclass of TableGen 'Attr' class"); |
41 | } |
42 | |
43 | Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} |
44 | |
45 | bool Attribute::isDerivedAttr() const { return isSubClassOf(className: "DerivedAttr"); } |
46 | |
47 | bool Attribute::isTypeAttr() const { return isSubClassOf(className: "TypeAttrBase"); } |
48 | |
49 | bool Attribute::isSymbolRefAttr() const { |
50 | StringRef defName = def->getName(); |
51 | if (defName == "SymbolRefAttr"|| defName == "FlatSymbolRefAttr") |
52 | return true; |
53 | return isSubClassOf(className: "SymbolRefAttr") || isSubClassOf(className: "FlatSymbolRefAttr"); |
54 | } |
55 | |
56 | bool Attribute::isEnumAttr() const { return isSubClassOf(className: "EnumAttrInfo"); } |
57 | |
58 | StringRef Attribute::getStorageType() const { |
59 | const auto *init = def->getValueInit(FieldName: "storageType"); |
60 | auto type = getValueAsString(init); |
61 | if (type.empty()) |
62 | return "::mlir::Attribute"; |
63 | return type; |
64 | } |
65 | |
66 | StringRef Attribute::getReturnType() const { |
67 | const auto *init = def->getValueInit(FieldName: "returnType"); |
68 | return getValueAsString(init); |
69 | } |
70 | |
71 | // Return the type constraint corresponding to the type of this attribute, or |
72 | // std::nullopt if this is not a TypedAttr. |
73 | std::optional<Type> Attribute::getValueType() const { |
74 | if (auto *defInit = dyn_cast<llvm::DefInit>(Val: def->getValueInit(FieldName: "valueType"))) |
75 | return Type(defInit->getDef()); |
76 | return std::nullopt; |
77 | } |
78 | |
79 | StringRef Attribute::getConvertFromStorageCall() const { |
80 | const auto *init = def->getValueInit(FieldName: "convertFromStorage"); |
81 | return getValueAsString(init); |
82 | } |
83 | |
84 | bool Attribute::isConstBuildable() const { |
85 | const auto *init = def->getValueInit(FieldName: "constBuilderCall"); |
86 | return !getValueAsString(init).empty(); |
87 | } |
88 | |
89 | StringRef Attribute::getConstBuilderTemplate() const { |
90 | const auto *init = def->getValueInit(FieldName: "constBuilderCall"); |
91 | return getValueAsString(init); |
92 | } |
93 | |
94 | Attribute Attribute::getBaseAttr() const { |
95 | if (const auto *defInit = |
96 | llvm::dyn_cast<llvm::DefInit>(Val: def->getValueInit(FieldName: "baseAttr"))) { |
97 | return Attribute(defInit).getBaseAttr(); |
98 | } |
99 | return *this; |
100 | } |
101 | |
102 | bool Attribute::hasDefaultValue() const { |
103 | const auto *init = def->getValueInit(FieldName: "defaultValue"); |
104 | return !getValueAsString(init).empty(); |
105 | } |
106 | |
107 | StringRef Attribute::getDefaultValue() const { |
108 | const auto *init = def->getValueInit(FieldName: "defaultValue"); |
109 | return getValueAsString(init); |
110 | } |
111 | |
112 | bool Attribute::isOptional() const { return def->getValueAsBit(FieldName: "isOptional"); } |
113 | |
114 | StringRef Attribute::getAttrDefName() const { |
115 | if (def->isAnonymous()) { |
116 | return getBaseAttr().def->getName(); |
117 | } |
118 | return def->getName(); |
119 | } |
120 | |
121 | StringRef Attribute::getDerivedCodeBody() const { |
122 | assert(isDerivedAttr() && "only derived attribute has 'body' field"); |
123 | return def->getValueAsString(FieldName: "body"); |
124 | } |
125 | |
126 | Dialect Attribute::getDialect() const { |
127 | const llvm::RecordVal *record = def->getValue(Name: "dialect"); |
128 | if (record && record->getValue()) { |
129 | if (DefInit *init = dyn_cast<DefInit>(Val: record->getValue())) |
130 | return Dialect(init->getDef()); |
131 | } |
132 | return Dialect(nullptr); |
133 | } |
134 | |
135 | const llvm::Record &Attribute::getDef() const { return *def; } |
136 | |
137 | ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { |
138 | assert(def->isSubClassOf("ConstantAttr") && |
139 | "must be subclass of TableGen 'ConstantAttr' class"); |
140 | } |
141 | |
142 | Attribute ConstantAttr::getAttribute() const { |
143 | return Attribute(def->getValueAsDef(FieldName: "attr")); |
144 | } |
145 | |
146 | StringRef ConstantAttr::getConstantValue() const { |
147 | return def->getValueAsString(FieldName: "value"); |
148 | } |
149 | |
150 | EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { |
151 | assert(isSubClassOf("EnumAttrCaseInfo") && |
152 | "must be subclass of TableGen 'EnumAttrInfo' class"); |
153 | } |
154 | |
155 | EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) |
156 | : EnumAttrCase(init->getDef()) {} |
157 | |
158 | StringRef EnumAttrCase::getSymbol() const { |
159 | return def->getValueAsString(FieldName: "symbol"); |
160 | } |
161 | |
162 | StringRef EnumAttrCase::getStr() const { return def->getValueAsString(FieldName: "str"); } |
163 | |
164 | int64_t EnumAttrCase::getValue() const { return def->getValueAsInt(FieldName: "value"); } |
165 | |
166 | const llvm::Record &EnumAttrCase::getDef() const { return *def; } |
167 | |
168 | EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { |
169 | assert(isSubClassOf("EnumAttrInfo") && |
170 | "must be subclass of TableGen 'EnumAttr' class"); |
171 | } |
172 | |
173 | EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} |
174 | |
175 | EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} |
176 | |
177 | bool EnumAttr::classof(const Attribute *attr) { |
178 | return attr->isSubClassOf(className: "EnumAttrInfo"); |
179 | } |
180 | |
181 | bool EnumAttr::isBitEnum() const { return isSubClassOf(className: "BitEnumAttr"); } |
182 | |
183 | StringRef EnumAttr::getEnumClassName() const { |
184 | return def->getValueAsString(FieldName: "className"); |
185 | } |
186 | |
187 | StringRef EnumAttr::getCppNamespace() const { |
188 | return def->getValueAsString(FieldName: "cppNamespace"); |
189 | } |
190 | |
191 | StringRef EnumAttr::getUnderlyingType() const { |
192 | return def->getValueAsString(FieldName: "underlyingType"); |
193 | } |
194 | |
195 | StringRef EnumAttr::getUnderlyingToSymbolFnName() const { |
196 | return def->getValueAsString(FieldName: "underlyingToSymbolFnName"); |
197 | } |
198 | |
199 | StringRef EnumAttr::getStringToSymbolFnName() const { |
200 | return def->getValueAsString(FieldName: "stringToSymbolFnName"); |
201 | } |
202 | |
203 | StringRef EnumAttr::getSymbolToStringFnName() const { |
204 | return def->getValueAsString(FieldName: "symbolToStringFnName"); |
205 | } |
206 | |
207 | StringRef EnumAttr::getSymbolToStringFnRetType() const { |
208 | return def->getValueAsString(FieldName: "symbolToStringFnRetType"); |
209 | } |
210 | |
211 | StringRef EnumAttr::getMaxEnumValFnName() const { |
212 | return def->getValueAsString(FieldName: "maxEnumValFnName"); |
213 | } |
214 | |
215 | std::vector<EnumAttrCase> EnumAttr::getAllCases() const { |
216 | const auto *inits = def->getValueAsListInit(FieldName: "enumerants"); |
217 | |
218 | std::vector<EnumAttrCase> cases; |
219 | cases.reserve(n: inits->size()); |
220 | |
221 | for (const llvm::Init *init : *inits) { |
222 | cases.emplace_back(args: cast<llvm::DefInit>(Val: init)); |
223 | } |
224 | |
225 | return cases; |
226 | } |
227 | |
228 | bool EnumAttr::genSpecializedAttr() const { |
229 | return def->getValueAsBit(FieldName: "genSpecializedAttr"); |
230 | } |
231 | |
232 | llvm::Record *EnumAttr::getBaseAttrClass() const { |
233 | return def->getValueAsDef(FieldName: "baseAttrClass"); |
234 | } |
235 | |
236 | StringRef EnumAttr::getSpecializedAttrClassName() const { |
237 | return def->getValueAsString(FieldName: "specializedAttrClassName"); |
238 | } |
239 | |
240 | bool EnumAttr::printBitEnumPrimaryGroups() const { |
241 | return def->getValueAsBit(FieldName: "printBitEnumPrimaryGroups"); |
242 | } |
243 | |
244 | const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; |
245 |