1 | //===-- FIRAttr.cpp -------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Error.h"
23 | |
24 | #include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc" |
25 | #define GET_ATTRDEF_CLASSES |
26 | #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc" |
27 | |
28 | using namespace fir; |
29 | |
30 | namespace fir::detail { |
31 | |
32 | struct RealAttributeStorage : public mlir::AttributeStorage { |
33 | using KeyTy = std::pair<int, llvm::APFloat>; |
34 | |
35 | RealAttributeStorage(int kind, const llvm::APFloat &value) |
36 | : kind(kind), value(value) {} |
37 | RealAttributeStorage(const KeyTy &key) |
38 | : RealAttributeStorage(key.first, key.second) {} |
39 | |
40 | static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(arg: key); } |
41 | |
42 | bool operator==(const KeyTy &key) const { |
43 | return key.first == kind && |
44 | key.second.compare(RHS: value) == llvm::APFloatBase::cmpEqual; |
45 | } |
46 | |
47 | static RealAttributeStorage * |
48 | construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { |
49 | return new (allocator.allocate<RealAttributeStorage>()) |
50 | RealAttributeStorage(key); |
51 | } |
52 | |
53 | KindTy getFKind() const { return kind; } |
54 | llvm::APFloat getValue() const { return value; } |
55 | |
56 | private: |
57 | int kind; |
58 | llvm::APFloat value; |
59 | }; |
60 | |
61 | /// An attribute representing a reference to a type. |
62 | struct TypeAttributeStorage : public mlir::AttributeStorage { |
63 | using KeyTy = mlir::Type; |
64 | |
65 | TypeAttributeStorage(mlir::Type value) : value(value) { |
66 | assert(value && "must not be of Type null" ); |
67 | } |
68 | |
69 | /// Key equality function. |
70 | bool operator==(const KeyTy &key) const { return key == value; } |
71 | |
72 | /// Construct a new storage instance. |
73 | static TypeAttributeStorage * |
74 | construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) { |
75 | return new (allocator.allocate<TypeAttributeStorage>()) |
76 | TypeAttributeStorage(key); |
77 | } |
78 | |
79 | mlir::Type getType() const { return value; } |
80 | |
81 | private: |
82 | mlir::Type value; |
83 | }; |
84 | } // namespace fir::detail |
85 | |
86 | //===----------------------------------------------------------------------===// |
87 | // Attributes for SELECT TYPE |
88 | //===----------------------------------------------------------------------===// |
89 | |
90 | ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) { |
91 | return Base::get(value.getContext(), value); |
92 | } |
93 | |
94 | mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); } |
95 | |
96 | SubclassAttr fir::SubclassAttr::get(mlir::Type value) { |
97 | return Base::get(value.getContext(), value); |
98 | } |
99 | |
100 | mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); } |
101 | |
102 | //===----------------------------------------------------------------------===// |
103 | // Attributes for SELECT CASE |
104 | //===----------------------------------------------------------------------===// |
105 | |
106 | using AttributeUniquer = mlir::detail::AttributeUniquer; |
107 | |
108 | ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) { |
109 | return AttributeUniquer::get<ClosedIntervalAttr>(ctxt); |
110 | } |
111 | |
112 | UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) { |
113 | return AttributeUniquer::get<UpperBoundAttr>(ctxt); |
114 | } |
115 | |
116 | LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) { |
117 | return AttributeUniquer::get<LowerBoundAttr>(ctxt); |
118 | } |
119 | |
120 | PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) { |
121 | return AttributeUniquer::get<PointIntervalAttr>(ctxt); |
122 | } |
123 | |
124 | //===----------------------------------------------------------------------===// |
125 | // RealAttr |
126 | //===----------------------------------------------------------------------===// |
127 | |
128 | RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt, |
129 | const RealAttr::ValueType &key) { |
130 | return Base::get(ctxt, key); |
131 | } |
132 | |
133 | KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); } |
134 | |
135 | llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); } |
136 | |
137 | //===----------------------------------------------------------------------===// |
138 | // FIR attribute parsing |
139 | //===----------------------------------------------------------------------===// |
140 | |
141 | static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect, |
142 | mlir::DialectAsmParser &parser, |
143 | mlir::Type type) { |
144 | int kind = 0; |
145 | if (parser.parseLess() || parser.parseInteger(result&: kind) || parser.parseComma()) { |
146 | parser.emitError(loc: parser.getNameLoc(), message: "expected '<' kind ','" ); |
147 | return {}; |
148 | } |
149 | KindMapping kindMap(dialect->getContext()); |
150 | llvm::APFloat value(0.); |
151 | if (parser.parseOptionalKeyword(keyword: "i" )) { |
152 | // `i` not present, so literal float must be present |
153 | double dontCare; |
154 | if (parser.parseFloat(result&: dontCare) || parser.parseGreater()) { |
155 | parser.emitError(loc: parser.getNameLoc(), message: "expected real constant '>'" ); |
156 | return {}; |
157 | } |
158 | auto fltStr = parser.getFullSymbolSpec() |
159 | .drop_until(F: [](char c) { return c == ','; }) |
160 | .drop_front() |
161 | .drop_while(F: [](char c) { return c == ' ' || c == '\t'; }) |
162 | .take_until(F: [](char c) { |
163 | return c == '>' || c == ' ' || c == '\t'; |
164 | }); |
165 | value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr); |
166 | } else { |
167 | // `i` is present, so literal bitstring (hex) must be present |
168 | llvm::StringRef hex; |
169 | if (parser.parseKeyword(keyword: &hex) || parser.parseGreater()) { |
170 | parser.emitError(loc: parser.getNameLoc(), message: "expected real constant '>'" ); |
171 | return {}; |
172 | } |
173 | const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind); |
174 | unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem); |
175 | auto bits = llvm::APInt(numBits, hex.drop_front(), 16); |
176 | value = llvm::APFloat(sem, bits); |
177 | } |
178 | return RealAttr::get(dialect->getContext(), {kind, value}); |
179 | } |
180 | |
181 | mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser, |
182 | mlir::Type type) { |
183 | if (mlir::failed(parser.parseLess())) |
184 | return {}; |
185 | |
186 | fir::FortranVariableFlagsEnum flags = {}; |
187 | if (mlir::failed(parser.parseOptionalGreater())) { |
188 | auto parseFlags = [&]() -> mlir::ParseResult { |
189 | llvm::StringRef elemName; |
190 | if (mlir::failed(parser.parseKeyword(&elemName))) |
191 | return mlir::failure(); |
192 | |
193 | auto elem = fir::symbolizeFortranVariableFlagsEnum(elemName); |
194 | if (!elem) |
195 | return parser.emitError(parser.getNameLoc(), |
196 | "Unknown fortran variable attribute: " ) |
197 | << elemName; |
198 | |
199 | flags = flags | *elem; |
200 | return mlir::success(); |
201 | }; |
202 | if (mlir::failed(parser.parseCommaSeparatedList(parseFlags)) || |
203 | parser.parseGreater()) |
204 | return {}; |
205 | } |
206 | |
207 | return FortranVariableFlagsAttr::get(parser.getContext(), flags); |
208 | } |
209 | |
210 | mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect, |
211 | mlir::DialectAsmParser &parser, |
212 | mlir::Type type) { |
213 | auto loc = parser.getNameLoc(); |
214 | llvm::StringRef attrName; |
215 | mlir::Attribute attr; |
216 | mlir::OptionalParseResult result = |
217 | generatedAttributeParser(parser, &attrName, type, attr); |
218 | if (result.has_value()) |
219 | return attr; |
220 | if (attrName.empty()) |
221 | return {}; // error reported by generatedAttributeParser |
222 | |
223 | if (attrName == ExactTypeAttr::getAttrName()) { |
224 | mlir::Type type; |
225 | if (parser.parseLess() || parser.parseType(result&: type) || parser.parseGreater()) { |
226 | parser.emitError(loc, message: "expected a type" ); |
227 | return {}; |
228 | } |
229 | return ExactTypeAttr::get(type); |
230 | } |
231 | if (attrName == SubclassAttr::getAttrName()) { |
232 | mlir::Type type; |
233 | if (parser.parseLess() || parser.parseType(result&: type) || parser.parseGreater()) { |
234 | parser.emitError(loc, message: "expected a subtype" ); |
235 | return {}; |
236 | } |
237 | return SubclassAttr::get(type); |
238 | } |
239 | if (attrName == PointIntervalAttr::getAttrName()) |
240 | return PointIntervalAttr::get(dialect->getContext()); |
241 | if (attrName == LowerBoundAttr::getAttrName()) |
242 | return LowerBoundAttr::get(dialect->getContext()); |
243 | if (attrName == UpperBoundAttr::getAttrName()) |
244 | return UpperBoundAttr::get(dialect->getContext()); |
245 | if (attrName == ClosedIntervalAttr::getAttrName()) |
246 | return ClosedIntervalAttr::get(dialect->getContext()); |
247 | if (attrName == RealAttr::getAttrName()) |
248 | return parseFirRealAttr(dialect, parser, type); |
249 | |
250 | parser.emitError(loc, message: "unknown FIR attribute: " ) << attrName; |
251 | return {}; |
252 | } |
253 | |
254 | //===----------------------------------------------------------------------===// |
255 | // FIR attribute pretty printer |
256 | //===----------------------------------------------------------------------===// |
257 | |
258 | void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const { |
259 | printer << "<" ; |
260 | printer << fir::stringifyFortranVariableFlagsEnum(this->getFlags()); |
261 | printer << ">" ; |
262 | } |
263 | |
264 | void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, |
265 | mlir::DialectAsmPrinter &p) { |
266 | auto &os = p.getStream(); |
267 | if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) { |
268 | os << fir::ExactTypeAttr::getAttrName() << '<'; |
269 | p.printType(type: exact.getType()); |
270 | os << '>'; |
271 | } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) { |
272 | os << fir::SubclassAttr::getAttrName() << '<'; |
273 | p.printType(type: sub.getType()); |
274 | os << '>'; |
275 | } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) { |
276 | os << fir::PointIntervalAttr::getAttrName(); |
277 | } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { |
278 | os << fir::ClosedIntervalAttr::getAttrName(); |
279 | } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) { |
280 | os << fir::LowerBoundAttr::getAttrName(); |
281 | } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) { |
282 | os << fir::UpperBoundAttr::getAttrName(); |
283 | } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) { |
284 | os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x" ; |
285 | llvm::SmallString<40> ss; |
286 | a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16); |
287 | os << ss << '>'; |
288 | } else if (mlir::failed(result: generatedAttributePrinter(attr, p))) { |
289 | // don't know how to print the attribute, so use a default |
290 | os << "<(unknown attribute)>" ; |
291 | } |
292 | } |
293 | |
294 | //===----------------------------------------------------------------------===// |
295 | // FIROpsDialect |
296 | //===----------------------------------------------------------------------===// |
297 | |
298 | void FIROpsDialect::registerAttributes() { |
299 | addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr, |
300 | LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr, |
301 | UpperBoundAttr, CUDADataAttributeAttr, CUDAProcAttributeAttr, |
302 | CUDALaunchBoundsAttr, CUDAClusterDimsAttr, |
303 | CUDADataTransferKindAttr>(); |
304 | } |
305 | |