1//===-- FIRAttr.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Error.h"
23
24#include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc"
25#define GET_ATTRDEF_CLASSES
26#include "flang/Optimizer/Dialect/FIRAttr.cpp.inc"
27
28using namespace fir;
29
30namespace fir::detail {
31
32struct RealAttributeStorage : public mlir::AttributeStorage {
33 using KeyTy = std::pair<int, llvm::APFloat>;
34
35 RealAttributeStorage(int kind, const llvm::APFloat &value)
36 : kind(kind), value(value) {}
37 RealAttributeStorage(const KeyTy &key)
38 : RealAttributeStorage(key.first, key.second) {}
39
40 static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(arg: key); }
41
42 bool operator==(const KeyTy &key) const {
43 return key.first == kind &&
44 key.second.compare(RHS: value) == llvm::APFloatBase::cmpEqual;
45 }
46
47 static RealAttributeStorage *
48 construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
49 return new (allocator.allocate<RealAttributeStorage>())
50 RealAttributeStorage(key);
51 }
52
53 KindTy getFKind() const { return kind; }
54 llvm::APFloat getValue() const { return value; }
55
56private:
57 int kind;
58 llvm::APFloat value;
59};
60
61/// An attribute representing a reference to a type.
62struct TypeAttributeStorage : public mlir::AttributeStorage {
63 using KeyTy = mlir::Type;
64
65 TypeAttributeStorage(mlir::Type value) : value(value) {
66 assert(value && "must not be of Type null");
67 }
68
69 /// Key equality function.
70 bool operator==(const KeyTy &key) const { return key == value; }
71
72 /// Construct a new storage instance.
73 static TypeAttributeStorage *
74 construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
75 return new (allocator.allocate<TypeAttributeStorage>())
76 TypeAttributeStorage(key);
77 }
78
79 mlir::Type getType() const { return value; }
80
81private:
82 mlir::Type value;
83};
84} // namespace fir::detail
85
86//===----------------------------------------------------------------------===//
87// Attributes for SELECT TYPE
88//===----------------------------------------------------------------------===//
89
90ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
91 return Base::get(value.getContext(), value);
92}
93
94mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
95
96SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
97 return Base::get(value.getContext(), value);
98}
99
100mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); }
101
102//===----------------------------------------------------------------------===//
103// Attributes for SELECT CASE
104//===----------------------------------------------------------------------===//
105
106using AttributeUniquer = mlir::detail::AttributeUniquer;
107
108ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
109 return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
110}
111
112UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
113 return AttributeUniquer::get<UpperBoundAttr>(ctxt);
114}
115
116LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
117 return AttributeUniquer::get<LowerBoundAttr>(ctxt);
118}
119
120PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
121 return AttributeUniquer::get<PointIntervalAttr>(ctxt);
122}
123
124//===----------------------------------------------------------------------===//
125// RealAttr
126//===----------------------------------------------------------------------===//
127
128RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
129 const RealAttr::ValueType &key) {
130 return Base::get(ctxt, key);
131}
132
133KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
134
135llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); }
136
137//===----------------------------------------------------------------------===//
138// FIR attribute parsing
139//===----------------------------------------------------------------------===//
140
141static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
142 mlir::DialectAsmParser &parser,
143 mlir::Type type) {
144 int kind = 0;
145 if (parser.parseLess() || parser.parseInteger(result&: kind) || parser.parseComma()) {
146 parser.emitError(loc: parser.getNameLoc(), message: "expected '<' kind ','");
147 return {};
148 }
149 KindMapping kindMap(dialect->getContext());
150 llvm::APFloat value(0.);
151 if (parser.parseOptionalKeyword(keyword: "i")) {
152 // `i` not present, so literal float must be present
153 double dontCare;
154 if (parser.parseFloat(result&: dontCare) || parser.parseGreater()) {
155 parser.emitError(loc: parser.getNameLoc(), message: "expected real constant '>'");
156 return {};
157 }
158 auto fltStr = parser.getFullSymbolSpec()
159 .drop_until(F: [](char c) { return c == ','; })
160 .drop_front()
161 .drop_while(F: [](char c) { return c == ' ' || c == '\t'; })
162 .take_until(F: [](char c) {
163 return c == '>' || c == ' ' || c == '\t';
164 });
165 value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
166 } else {
167 // `i` is present, so literal bitstring (hex) must be present
168 llvm::StringRef hex;
169 if (parser.parseKeyword(keyword: &hex) || parser.parseGreater()) {
170 parser.emitError(loc: parser.getNameLoc(), message: "expected real constant '>'");
171 return {};
172 }
173 const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind);
174 unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem);
175 auto bits = llvm::APInt(numBits, hex.drop_front(), 16);
176 value = llvm::APFloat(sem, bits);
177 }
178 return RealAttr::get(dialect->getContext(), {kind, value});
179}
180
181mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser,
182 mlir::Type type) {
183 if (mlir::failed(parser.parseLess()))
184 return {};
185
186 fir::FortranVariableFlagsEnum flags = {};
187 if (mlir::failed(parser.parseOptionalGreater())) {
188 auto parseFlags = [&]() -> mlir::ParseResult {
189 llvm::StringRef elemName;
190 if (mlir::failed(parser.parseKeyword(&elemName)))
191 return mlir::failure();
192
193 auto elem = fir::symbolizeFortranVariableFlagsEnum(elemName);
194 if (!elem)
195 return parser.emitError(parser.getNameLoc(),
196 "Unknown fortran variable attribute: ")
197 << elemName;
198
199 flags = flags | *elem;
200 return mlir::success();
201 };
202 if (mlir::failed(parser.parseCommaSeparatedList(parseFlags)) ||
203 parser.parseGreater())
204 return {};
205 }
206
207 return FortranVariableFlagsAttr::get(parser.getContext(), flags);
208}
209
210mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
211 mlir::DialectAsmParser &parser,
212 mlir::Type type) {
213 auto loc = parser.getNameLoc();
214 llvm::StringRef attrName;
215 mlir::Attribute attr;
216 mlir::OptionalParseResult result =
217 generatedAttributeParser(parser, &attrName, type, attr);
218 if (result.has_value())
219 return attr;
220 if (attrName.empty())
221 return {}; // error reported by generatedAttributeParser
222
223 if (attrName == ExactTypeAttr::getAttrName()) {
224 mlir::Type type;
225 if (parser.parseLess() || parser.parseType(result&: type) || parser.parseGreater()) {
226 parser.emitError(loc, message: "expected a type");
227 return {};
228 }
229 return ExactTypeAttr::get(type);
230 }
231 if (attrName == SubclassAttr::getAttrName()) {
232 mlir::Type type;
233 if (parser.parseLess() || parser.parseType(result&: type) || parser.parseGreater()) {
234 parser.emitError(loc, message: "expected a subtype");
235 return {};
236 }
237 return SubclassAttr::get(type);
238 }
239 if (attrName == PointIntervalAttr::getAttrName())
240 return PointIntervalAttr::get(dialect->getContext());
241 if (attrName == LowerBoundAttr::getAttrName())
242 return LowerBoundAttr::get(dialect->getContext());
243 if (attrName == UpperBoundAttr::getAttrName())
244 return UpperBoundAttr::get(dialect->getContext());
245 if (attrName == ClosedIntervalAttr::getAttrName())
246 return ClosedIntervalAttr::get(dialect->getContext());
247 if (attrName == RealAttr::getAttrName())
248 return parseFirRealAttr(dialect, parser, type);
249
250 parser.emitError(loc, message: "unknown FIR attribute: ") << attrName;
251 return {};
252}
253
254//===----------------------------------------------------------------------===//
255// FIR attribute pretty printer
256//===----------------------------------------------------------------------===//
257
258void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const {
259 printer << "<";
260 printer << fir::stringifyFortranVariableFlagsEnum(this->getFlags());
261 printer << ">";
262}
263
264void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
265 mlir::DialectAsmPrinter &p) {
266 auto &os = p.getStream();
267 if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
268 os << fir::ExactTypeAttr::getAttrName() << '<';
269 p.printType(type: exact.getType());
270 os << '>';
271 } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) {
272 os << fir::SubclassAttr::getAttrName() << '<';
273 p.printType(type: sub.getType());
274 os << '>';
275 } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) {
276 os << fir::PointIntervalAttr::getAttrName();
277 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
278 os << fir::ClosedIntervalAttr::getAttrName();
279 } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) {
280 os << fir::LowerBoundAttr::getAttrName();
281 } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) {
282 os << fir::UpperBoundAttr::getAttrName();
283 } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) {
284 os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
285 llvm::SmallString<40> ss;
286 a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
287 os << ss << '>';
288 } else if (mlir::failed(result: generatedAttributePrinter(attr, p))) {
289 // don't know how to print the attribute, so use a default
290 os << "<(unknown attribute)>";
291 }
292}
293
294//===----------------------------------------------------------------------===//
295// FIROpsDialect
296//===----------------------------------------------------------------------===//
297
298void FIROpsDialect::registerAttributes() {
299 addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
300 LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
301 UpperBoundAttr, CUDADataAttributeAttr, CUDAProcAttributeAttr,
302 CUDALaunchBoundsAttr, CUDAClusterDimsAttr,
303 CUDADataTransferKindAttr>();
304}
305

source code of flang/lib/Optimizer/Dialect/FIRAttr.cpp