| 1 | //===-- FIRAttr.cpp -------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Error.h"
| 23 | |
| 24 | #include "flang/Optimizer/Dialect/FIREnumAttr.cpp.inc" |
| 25 | #define GET_ATTRDEF_CLASSES |
| 26 | #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc" |
| 27 | |
| 28 | using namespace fir; |
| 29 | |
| 30 | namespace fir::detail { |
| 31 | |
| 32 | struct RealAttributeStorage : public mlir::AttributeStorage { |
| 33 | using KeyTy = std::pair<int, llvm::APFloat>; |
| 34 | |
| 35 | RealAttributeStorage(int kind, const llvm::APFloat &value) |
| 36 | : kind(kind), value(value) {} |
| 37 | RealAttributeStorage(const KeyTy &key) |
| 38 | : RealAttributeStorage(key.first, key.second) {} |
| 39 | |
| 40 | static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(arg: key); } |
| 41 | |
| 42 | bool operator==(const KeyTy &key) const { |
| 43 | return key.first == kind && |
| 44 | key.second.compare(RHS: value) == llvm::APFloatBase::cmpEqual; |
| 45 | } |
| 46 | |
| 47 | static RealAttributeStorage * |
| 48 | construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { |
| 49 | return new (allocator.allocate<RealAttributeStorage>()) |
| 50 | RealAttributeStorage(key); |
| 51 | } |
| 52 | |
| 53 | KindTy getFKind() const { return kind; } |
| 54 | llvm::APFloat getValue() const { return value; } |
| 55 | |
| 56 | private: |
| 57 | int kind; |
| 58 | llvm::APFloat value; |
| 59 | }; |
| 60 | |
| 61 | /// An attribute representing a reference to a type. |
| 62 | struct TypeAttributeStorage : public mlir::AttributeStorage { |
| 63 | using KeyTy = mlir::Type; |
| 64 | |
| 65 | TypeAttributeStorage(mlir::Type value) : value(value) { |
| 66 | assert(value && "must not be of Type null" ); |
| 67 | } |
| 68 | |
| 69 | /// Key equality function. |
| 70 | bool operator==(const KeyTy &key) const { return key == value; } |
| 71 | |
| 72 | /// Construct a new storage instance. |
| 73 | static TypeAttributeStorage * |
| 74 | construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) { |
| 75 | return new (allocator.allocate<TypeAttributeStorage>()) |
| 76 | TypeAttributeStorage(key); |
| 77 | } |
| 78 | |
| 79 | mlir::Type getType() const { return value; } |
| 80 | |
| 81 | private: |
| 82 | mlir::Type value; |
| 83 | }; |
| 84 | } // namespace fir::detail |
| 85 | |
| 86 | //===----------------------------------------------------------------------===// |
| 87 | // Attributes for SELECT TYPE |
| 88 | //===----------------------------------------------------------------------===// |
| 89 | |
| 90 | ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) { |
| 91 | return Base::get(value.getContext(), value); |
| 92 | } |
| 93 | |
| 94 | mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); } |
| 95 | |
| 96 | SubclassAttr fir::SubclassAttr::get(mlir::Type value) { |
| 97 | return Base::get(value.getContext(), value); |
| 98 | } |
| 99 | |
| 100 | mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); } |
| 101 | |
| 102 | //===----------------------------------------------------------------------===// |
| 103 | // Attributes for SELECT CASE |
| 104 | //===----------------------------------------------------------------------===// |
| 105 | |
| 106 | using AttributeUniquer = mlir::detail::AttributeUniquer; |
| 107 | |
| 108 | ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) { |
| 109 | return AttributeUniquer::get<ClosedIntervalAttr>(ctxt); |
| 110 | } |
| 111 | |
| 112 | UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) { |
| 113 | return AttributeUniquer::get<UpperBoundAttr>(ctxt); |
| 114 | } |
| 115 | |
| 116 | LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) { |
| 117 | return AttributeUniquer::get<LowerBoundAttr>(ctxt); |
| 118 | } |
| 119 | |
| 120 | PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) { |
| 121 | return AttributeUniquer::get<PointIntervalAttr>(ctxt); |
| 122 | } |
| 123 | |
| 124 | //===----------------------------------------------------------------------===// |
| 125 | // RealAttr |
| 126 | //===----------------------------------------------------------------------===// |
| 127 | |
| 128 | RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt, |
| 129 | const RealAttr::ValueType &key) { |
| 130 | return Base::get(ctxt, key); |
| 131 | } |
| 132 | |
| 133 | KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); } |
| 134 | |
| 135 | llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); } |
| 136 | |
| 137 | //===----------------------------------------------------------------------===// |
| 138 | // FIR attribute parsing |
| 139 | //===----------------------------------------------------------------------===// |
| 140 | |
| 141 | static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect, |
| 142 | mlir::DialectAsmParser &parser, |
| 143 | mlir::Type type) { |
| 144 | int kind = 0; |
| 145 | if (parser.parseLess() || parser.parseInteger(result&: kind) || parser.parseComma()) { |
| 146 | parser.emitError(loc: parser.getNameLoc(), message: "expected '<' kind ','" ); |
| 147 | return {}; |
| 148 | } |
| 149 | KindMapping kindMap(dialect->getContext()); |
| 150 | llvm::APFloat value(0.); |
| 151 | if (parser.parseOptionalKeyword(keyword: "i" )) { |
| 152 | // `i` not present, so literal float must be present |
| 153 | double dontCare; |
| 154 | if (parser.parseFloat(result&: dontCare) || parser.parseGreater()) { |
| 155 | parser.emitError(loc: parser.getNameLoc(), message: "expected real constant '>'" ); |
| 156 | return {}; |
| 157 | } |
| 158 | auto fltStr = parser.getFullSymbolSpec() |
| 159 | .drop_until(F: [](char c) { return c == ','; }) |
| 160 | .drop_front() |
| 161 | .drop_while(F: [](char c) { return c == ' ' || c == '\t'; }) |
| 162 | .take_until(F: [](char c) { |
| 163 | return c == '>' || c == ' ' || c == '\t'; |
| 164 | }); |
| 165 | value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr); |
| 166 | } else { |
| 167 | // `i` is present, so literal bitstring (hex) must be present |
| 168 | llvm::StringRef hex; |
| 169 | if (parser.parseKeyword(keyword: &hex) || parser.parseGreater()) { |
| 170 | parser.emitError(loc: parser.getNameLoc(), message: "expected real constant '>'" ); |
| 171 | return {}; |
| 172 | } |
| 173 | const llvm::fltSemantics &sem = kindMap.getFloatSemantics(kind); |
| 174 | unsigned int numBits = llvm::APFloat::semanticsSizeInBits(sem); |
| 175 | auto bits = llvm::APInt(numBits, hex.drop_front(), 16); |
| 176 | value = llvm::APFloat(sem, bits); |
| 177 | } |
| 178 | return RealAttr::get(dialect->getContext(), {kind, value}); |
| 179 | } |
| 180 | |
| 181 | mlir::Attribute fir::FortranVariableFlagsAttr::parse(mlir::AsmParser &parser, |
| 182 | mlir::Type type) { |
| 183 | if (mlir::failed(parser.parseLess())) |
| 184 | return {}; |
| 185 | |
| 186 | fir::FortranVariableFlagsEnum flags = {}; |
| 187 | if (mlir::failed(parser.parseOptionalGreater())) { |
| 188 | auto parseFlags = [&]() -> mlir::ParseResult { |
| 189 | llvm::StringRef elemName; |
| 190 | if (mlir::failed(parser.parseKeyword(&elemName))) |
| 191 | return mlir::failure(); |
| 192 | |
| 193 | auto elem = fir::symbolizeFortranVariableFlagsEnum(elemName); |
| 194 | if (!elem) |
| 195 | return parser.emitError(parser.getNameLoc(), |
| 196 | "Unknown fortran variable attribute: " ) |
| 197 | << elemName; |
| 198 | |
| 199 | flags = flags | *elem; |
| 200 | return mlir::success(); |
| 201 | }; |
| 202 | if (mlir::failed(parser.parseCommaSeparatedList(parseFlags)) || |
| 203 | parser.parseGreater()) |
| 204 | return {}; |
| 205 | } |
| 206 | |
| 207 | return FortranVariableFlagsAttr::get(parser.getContext(), flags); |
| 208 | } |
| 209 | |
| 210 | mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect, |
| 211 | mlir::DialectAsmParser &parser, |
| 212 | mlir::Type type) { |
| 213 | auto loc = parser.getNameLoc(); |
| 214 | llvm::StringRef attrName; |
| 215 | mlir::Attribute attr; |
| 216 | mlir::OptionalParseResult result = |
| 217 | generatedAttributeParser(parser, &attrName, type, attr); |
| 218 | if (result.has_value()) |
| 219 | return attr; |
| 220 | if (attrName.empty()) |
| 221 | return {}; // error reported by generatedAttributeParser |
| 222 | |
| 223 | if (attrName == ExactTypeAttr::getAttrName()) { |
| 224 | mlir::Type type; |
| 225 | if (parser.parseLess() || parser.parseType(result&: type) || parser.parseGreater()) { |
| 226 | parser.emitError(loc, message: "expected a type" ); |
| 227 | return {}; |
| 228 | } |
| 229 | return ExactTypeAttr::get(type); |
| 230 | } |
| 231 | if (attrName == SubclassAttr::getAttrName()) { |
| 232 | mlir::Type type; |
| 233 | if (parser.parseLess() || parser.parseType(result&: type) || parser.parseGreater()) { |
| 234 | parser.emitError(loc, message: "expected a subtype" ); |
| 235 | return {}; |
| 236 | } |
| 237 | return SubclassAttr::get(type); |
| 238 | } |
| 239 | if (attrName == PointIntervalAttr::getAttrName()) |
| 240 | return PointIntervalAttr::get(dialect->getContext()); |
| 241 | if (attrName == LowerBoundAttr::getAttrName()) |
| 242 | return LowerBoundAttr::get(dialect->getContext()); |
| 243 | if (attrName == UpperBoundAttr::getAttrName()) |
| 244 | return UpperBoundAttr::get(dialect->getContext()); |
| 245 | if (attrName == ClosedIntervalAttr::getAttrName()) |
| 246 | return ClosedIntervalAttr::get(dialect->getContext()); |
| 247 | if (attrName == RealAttr::getAttrName()) |
| 248 | return parseFirRealAttr(dialect, parser, type); |
| 249 | |
| 250 | parser.emitError(loc, message: "unknown FIR attribute: " ) << attrName; |
| 251 | return {}; |
| 252 | } |
| 253 | |
| 254 | //===----------------------------------------------------------------------===// |
| 255 | // FIR attribute pretty printer |
| 256 | //===----------------------------------------------------------------------===// |
| 257 | |
| 258 | void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const { |
| 259 | printer << "<" ; |
| 260 | printer << fir::stringifyFortranVariableFlagsEnum(this->getFlags()); |
| 261 | printer << ">" ; |
| 262 | } |
| 263 | |
| 264 | void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, |
| 265 | mlir::DialectAsmPrinter &p) { |
| 266 | auto &os = p.getStream(); |
| 267 | if (auto exact = mlir::dyn_cast<fir::ExactTypeAttr>(attr)) { |
| 268 | os << fir::ExactTypeAttr::getAttrName() << '<'; |
| 269 | p.printType(type: exact.getType()); |
| 270 | os << '>'; |
| 271 | } else if (auto sub = mlir::dyn_cast<fir::SubclassAttr>(attr)) { |
| 272 | os << fir::SubclassAttr::getAttrName() << '<'; |
| 273 | p.printType(type: sub.getType()); |
| 274 | os << '>'; |
| 275 | } else if (mlir::dyn_cast_or_null<fir::PointIntervalAttr>(attr)) { |
| 276 | os << fir::PointIntervalAttr::getAttrName(); |
| 277 | } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) { |
| 278 | os << fir::ClosedIntervalAttr::getAttrName(); |
| 279 | } else if (mlir::dyn_cast_or_null<fir::LowerBoundAttr>(attr)) { |
| 280 | os << fir::LowerBoundAttr::getAttrName(); |
| 281 | } else if (mlir::dyn_cast_or_null<fir::UpperBoundAttr>(attr)) { |
| 282 | os << fir::UpperBoundAttr::getAttrName(); |
| 283 | } else if (auto a = mlir::dyn_cast_or_null<fir::RealAttr>(attr)) { |
| 284 | os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x" ; |
| 285 | llvm::SmallString<40> ss; |
| 286 | a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16); |
| 287 | os << ss << '>'; |
| 288 | } else if (mlir::failed(Result: generatedAttributePrinter(attr, p))) { |
| 289 | // don't know how to print the attribute, so use a default |
| 290 | os << "<(unknown attribute)>" ; |
| 291 | } |
| 292 | } |
| 293 | |
| 294 | //===----------------------------------------------------------------------===// |
| 295 | // FIROpsDialect |
| 296 | //===----------------------------------------------------------------------===// |
| 297 | |
| 298 | void FIROpsDialect::registerAttributes() { |
| 299 | addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr, |
| 300 | PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr, |
| 301 | #define GET_ATTRDEF_LIST |
| 302 | #include "flang/Optimizer/Dialect/FIRAttr.cpp.inc" |
| 303 | >(); |
| 304 | } |
| 305 | |