1//===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Attribute wrapper to simplify using a TableGen Record that defines an MLIR
// Attribute.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/TableGen/Format.h"
15#include "mlir/TableGen/Operator.h"
16#include "llvm/TableGen/Record.h"
17
18using namespace mlir;
19using namespace mlir::tblgen;
20
21using llvm::DefInit;
22using llvm::Init;
23using llvm::Record;
24using llvm::StringInit;
25
26// Returns the initializer's value as string if the given TableGen initializer
27// is a code or string initializer. Returns the empty StringRef otherwise.
28static StringRef getValueAsString(const Init *init) {
29 if (const auto *str = dyn_cast<StringInit>(Val: init))
30 return str->getValue().trim();
31 return {};
32}
33
34bool AttrConstraint::isSubClassOf(StringRef className) const {
35 return def->isSubClassOf(Name: className);
36}
37
38Attribute::Attribute(const Record *record) : AttrConstraint(record) {
39 assert(record->isSubClassOf("Attr") &&
40 "must be subclass of TableGen 'Attr' class");
41}
42
43Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
44
45bool Attribute::isDerivedAttr() const { return isSubClassOf(className: "DerivedAttr"); }
46
47bool Attribute::isTypeAttr() const { return isSubClassOf(className: "TypeAttrBase"); }
48
49bool Attribute::isSymbolRefAttr() const {
50 StringRef defName = def->getName();
51 if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
52 return true;
53 return isSubClassOf(className: "SymbolRefAttr") || isSubClassOf(className: "FlatSymbolRefAttr");
54}
55
56bool Attribute::isEnumAttr() const { return isSubClassOf(className: "EnumAttrInfo"); }
57
58StringRef Attribute::getStorageType() const {
59 const auto *init = def->getValueInit(FieldName: "storageType");
60 auto type = getValueAsString(init);
61 if (type.empty())
62 return "::mlir::Attribute";
63 return type;
64}
65
66StringRef Attribute::getReturnType() const {
67 const auto *init = def->getValueInit(FieldName: "returnType");
68 return getValueAsString(init);
69}
70
71// Return the type constraint corresponding to the type of this attribute, or
72// std::nullopt if this is not a TypedAttr.
73std::optional<Type> Attribute::getValueType() const {
74 if (auto *defInit = dyn_cast<llvm::DefInit>(Val: def->getValueInit(FieldName: "valueType")))
75 return Type(defInit->getDef());
76 return std::nullopt;
77}
78
79StringRef Attribute::getConvertFromStorageCall() const {
80 const auto *init = def->getValueInit(FieldName: "convertFromStorage");
81 return getValueAsString(init);
82}
83
84bool Attribute::isConstBuildable() const {
85 const auto *init = def->getValueInit(FieldName: "constBuilderCall");
86 return !getValueAsString(init).empty();
87}
88
89StringRef Attribute::getConstBuilderTemplate() const {
90 const auto *init = def->getValueInit(FieldName: "constBuilderCall");
91 return getValueAsString(init);
92}
93
94Attribute Attribute::getBaseAttr() const {
95 if (const auto *defInit =
96 llvm::dyn_cast<llvm::DefInit>(Val: def->getValueInit(FieldName: "baseAttr"))) {
97 return Attribute(defInit).getBaseAttr();
98 }
99 return *this;
100}
101
102bool Attribute::hasDefaultValue() const {
103 const auto *init = def->getValueInit(FieldName: "defaultValue");
104 return !getValueAsString(init).empty();
105}
106
107StringRef Attribute::getDefaultValue() const {
108 const auto *init = def->getValueInit(FieldName: "defaultValue");
109 return getValueAsString(init);
110}
111
112bool Attribute::isOptional() const { return def->getValueAsBit(FieldName: "isOptional"); }
113
114StringRef Attribute::getAttrDefName() const {
115 if (def->isAnonymous()) {
116 return getBaseAttr().def->getName();
117 }
118 return def->getName();
119}
120
121StringRef Attribute::getDerivedCodeBody() const {
122 assert(isDerivedAttr() && "only derived attribute has 'body' field");
123 return def->getValueAsString(FieldName: "body");
124}
125
126Dialect Attribute::getDialect() const {
127 const llvm::RecordVal *record = def->getValue(Name: "dialect");
128 if (record && record->getValue()) {
129 if (DefInit *init = dyn_cast<DefInit>(Val: record->getValue()))
130 return Dialect(init->getDef());
131 }
132 return Dialect(nullptr);
133}
134
135const llvm::Record &Attribute::getDef() const { return *def; }
136
137ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
138 assert(def->isSubClassOf("ConstantAttr") &&
139 "must be subclass of TableGen 'ConstantAttr' class");
140}
141
142Attribute ConstantAttr::getAttribute() const {
143 return Attribute(def->getValueAsDef(FieldName: "attr"));
144}
145
146StringRef ConstantAttr::getConstantValue() const {
147 return def->getValueAsString(FieldName: "value");
148}
149
150EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
151 assert(isSubClassOf("EnumAttrCaseInfo") &&
152 "must be subclass of TableGen 'EnumAttrInfo' class");
153}
154
155EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
156 : EnumAttrCase(init->getDef()) {}
157
158StringRef EnumAttrCase::getSymbol() const {
159 return def->getValueAsString(FieldName: "symbol");
160}
161
162StringRef EnumAttrCase::getStr() const { return def->getValueAsString(FieldName: "str"); }
163
164int64_t EnumAttrCase::getValue() const { return def->getValueAsInt(FieldName: "value"); }
165
166const llvm::Record &EnumAttrCase::getDef() const { return *def; }
167
168EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
169 assert(isSubClassOf("EnumAttrInfo") &&
170 "must be subclass of TableGen 'EnumAttr' class");
171}
172
173EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
174
175EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
176
177bool EnumAttr::classof(const Attribute *attr) {
178 return attr->isSubClassOf(className: "EnumAttrInfo");
179}
180
181bool EnumAttr::isBitEnum() const { return isSubClassOf(className: "BitEnumAttr"); }
182
183StringRef EnumAttr::getEnumClassName() const {
184 return def->getValueAsString(FieldName: "className");
185}
186
187StringRef EnumAttr::getCppNamespace() const {
188 return def->getValueAsString(FieldName: "cppNamespace");
189}
190
191StringRef EnumAttr::getUnderlyingType() const {
192 return def->getValueAsString(FieldName: "underlyingType");
193}
194
195StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
196 return def->getValueAsString(FieldName: "underlyingToSymbolFnName");
197}
198
199StringRef EnumAttr::getStringToSymbolFnName() const {
200 return def->getValueAsString(FieldName: "stringToSymbolFnName");
201}
202
203StringRef EnumAttr::getSymbolToStringFnName() const {
204 return def->getValueAsString(FieldName: "symbolToStringFnName");
205}
206
207StringRef EnumAttr::getSymbolToStringFnRetType() const {
208 return def->getValueAsString(FieldName: "symbolToStringFnRetType");
209}
210
211StringRef EnumAttr::getMaxEnumValFnName() const {
212 return def->getValueAsString(FieldName: "maxEnumValFnName");
213}
214
215std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
216 const auto *inits = def->getValueAsListInit(FieldName: "enumerants");
217
218 std::vector<EnumAttrCase> cases;
219 cases.reserve(n: inits->size());
220
221 for (const llvm::Init *init : *inits) {
222 cases.emplace_back(args: cast<llvm::DefInit>(Val: init));
223 }
224
225 return cases;
226}
227
228bool EnumAttr::genSpecializedAttr() const {
229 return def->getValueAsBit(FieldName: "genSpecializedAttr");
230}
231
232llvm::Record *EnumAttr::getBaseAttrClass() const {
233 return def->getValueAsDef(FieldName: "baseAttrClass");
234}
235
236StringRef EnumAttr::getSpecializedAttrClassName() const {
237 return def->getValueAsString(FieldName: "specializedAttrClassName");
238}
239
240bool EnumAttr::printBitEnumPrimaryGroups() const {
241 return def->getValueAsBit(FieldName: "printBitEnumPrimaryGroups");
242}
243
244const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
245

// Source: mlir/lib/TableGen/Attribute.cpp