| 1 | //===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the attributes in the CIR dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| 14 | |
| 15 | #include "mlir/IR/DialectImplementation.h" |
| 16 | #include "llvm/ADT/TypeSwitch.h" |
| 17 | |
// Custom parse/print hooks for attribute parameters. They are declared here,
// ahead of the generated attribute definitions included below, so that the
// generated code can refer to them (presumably from `custom<...>` directives
// in the ODS assembly formats -- verify against the .td file). Their
// definitions appear later in this file.
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
                              mlir::Type ty);
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
                  mlir::FailureOr<llvm::APFloat> &value,
                  cir::CIRFPTypeInterface fpType);

// Pointer constants: the keyword `null` or an integer attribute.
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
                                       mlir::IntegerAttr &value);

static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
| 29 | |
| 30 | #define GET_ATTRDEF_CLASSES |
| 31 | #include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc" |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace cir; |
| 35 | |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | // General CIR parsing / printing |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | |
| 40 | Attribute CIRDialect::parseAttribute(DialectAsmParser &parser, |
| 41 | Type type) const { |
| 42 | llvm::SMLoc typeLoc = parser.getCurrentLocation(); |
| 43 | llvm::StringRef mnemonic; |
| 44 | Attribute genAttr; |
| 45 | OptionalParseResult parseResult = |
| 46 | generatedAttributeParser(parser, &mnemonic, type, genAttr); |
| 47 | if (parseResult.has_value()) |
| 48 | return genAttr; |
| 49 | parser.emitError(typeLoc, "unknown attribute in CIR dialect" ); |
| 50 | return Attribute(); |
| 51 | } |
| 52 | |
| 53 | void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { |
| 54 | if (failed(generatedAttributePrinter(attr, os))) |
| 55 | llvm_unreachable("unexpected CIR type kind" ); |
| 56 | } |
| 57 | |
| 58 | //===----------------------------------------------------------------------===// |
| 59 | // ConstPtrAttr definitions |
| 60 | //===----------------------------------------------------------------------===// |
| 61 | |
| 62 | // TODO(CIR): Consider encoding the null value differently and use conditional |
| 63 | // assembly format instead of custom parsing/printing. |
| 64 | static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) { |
| 65 | |
| 66 | if (parser.parseOptionalKeyword(keyword: "null" ).succeeded()) { |
| 67 | value = parser.getBuilder().getI64IntegerAttr(0); |
| 68 | return success(); |
| 69 | } |
| 70 | |
| 71 | return parser.parseAttribute(value); |
| 72 | } |
| 73 | |
| 74 | static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) { |
| 75 | if (!value.getInt()) |
| 76 | p << "null" ; |
| 77 | else |
| 78 | p << value; |
| 79 | } |
| 80 | |
| 81 | //===----------------------------------------------------------------------===// |
| 82 | // IntAttr definitions |
| 83 | //===----------------------------------------------------------------------===// |
| 84 | |
| 85 | Attribute IntAttr::parse(AsmParser &parser, Type odsType) { |
| 86 | mlir::APInt apValue; |
| 87 | |
| 88 | if (!mlir::isa<IntType>(odsType)) |
| 89 | return {}; |
| 90 | auto type = mlir::cast<IntType>(odsType); |
| 91 | |
| 92 | // Consume the '<' symbol. |
| 93 | if (parser.parseLess()) |
| 94 | return {}; |
| 95 | |
| 96 | // Fetch arbitrary precision integer value. |
| 97 | if (type.isSigned()) { |
| 98 | int64_t value = 0; |
| 99 | if (parser.parseInteger(value)) { |
| 100 | parser.emitError(parser.getCurrentLocation(), "expected integer value" ); |
| 101 | } else { |
| 102 | apValue = mlir::APInt(type.getWidth(), value, type.isSigned(), |
| 103 | /*implicitTrunc=*/true); |
| 104 | if (apValue.getSExtValue() != value) |
| 105 | parser.emitError(parser.getCurrentLocation(), |
| 106 | "integer value too large for the given type" ); |
| 107 | } |
| 108 | } else { |
| 109 | uint64_t value = 0; |
| 110 | if (parser.parseInteger(value)) { |
| 111 | parser.emitError(parser.getCurrentLocation(), "expected integer value" ); |
| 112 | } else { |
| 113 | apValue = mlir::APInt(type.getWidth(), value, type.isSigned(), |
| 114 | /*implicitTrunc=*/true); |
| 115 | if (apValue.getZExtValue() != value) |
| 116 | parser.emitError(parser.getCurrentLocation(), |
| 117 | "integer value too large for the given type" ); |
| 118 | } |
| 119 | } |
| 120 | |
| 121 | // Consume the '>' symbol. |
| 122 | if (parser.parseGreater()) |
| 123 | return {}; |
| 124 | |
| 125 | return IntAttr::get(type, apValue); |
| 126 | } |
| 127 | |
| 128 | void IntAttr::print(AsmPrinter &printer) const { |
| 129 | auto type = mlir::cast<IntType>(getType()); |
| 130 | printer << '<'; |
| 131 | if (type.isSigned()) |
| 132 | printer << getSInt(); |
| 133 | else |
| 134 | printer << getUInt(); |
| 135 | printer << '>'; |
| 136 | } |
| 137 | |
| 138 | LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 139 | Type type, APInt value) { |
| 140 | if (!mlir::isa<IntType>(type)) |
| 141 | return emitError() << "expected 'simple.int' type" ; |
| 142 | |
| 143 | auto intType = mlir::cast<IntType>(type); |
| 144 | if (value.getBitWidth() != intType.getWidth()) |
| 145 | return emitError() << "type and value bitwidth mismatch: " |
| 146 | << intType.getWidth() << " != " << value.getBitWidth(); |
| 147 | |
| 148 | return success(); |
| 149 | } |
| 150 | |
| 151 | //===----------------------------------------------------------------------===// |
| 152 | // FPAttr definitions |
| 153 | //===----------------------------------------------------------------------===// |
| 154 | |
| 155 | static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) { |
| 156 | p << value; |
| 157 | } |
| 158 | |
| 159 | static ParseResult parseFloatLiteral(AsmParser &parser, |
| 160 | FailureOr<APFloat> &value, |
| 161 | CIRFPTypeInterface fpType) { |
| 162 | |
| 163 | APFloat parsedValue(0.0); |
| 164 | if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue)) |
| 165 | return failure(); |
| 166 | |
| 167 | value.emplace(args&: parsedValue); |
| 168 | return success(); |
| 169 | } |
| 170 | |
| 171 | FPAttr FPAttr::getZero(Type type) { |
| 172 | return get(type, |
| 173 | APFloat::getZero( |
| 174 | mlir::cast<CIRFPTypeInterface>(type).getFloatSemantics())); |
| 175 | } |
| 176 | |
| 177 | LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 178 | CIRFPTypeInterface fpType, APFloat value) { |
| 179 | if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) != |
| 180 | APFloat::SemanticsToEnum(value.getSemantics())) |
| 181 | return emitError() << "floating-point semantics mismatch" ; |
| 182 | |
| 183 | return success(); |
| 184 | } |
| 185 | |
| 186 | //===----------------------------------------------------------------------===// |
| 187 | // ConstComplexAttr definitions |
| 188 | //===----------------------------------------------------------------------===// |
| 189 | |
| 190 | LogicalResult |
| 191 | ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 192 | cir::ComplexType type, mlir::TypedAttr real, |
| 193 | mlir::TypedAttr imag) { |
| 194 | mlir::Type elemType = type.getElementType(); |
| 195 | if (real.getType() != elemType) |
| 196 | return emitError() |
| 197 | << "type of the real part does not match the complex type" ; |
| 198 | |
| 199 | if (imag.getType() != elemType) |
| 200 | return emitError() |
| 201 | << "type of the imaginary part does not match the complex type" ; |
| 202 | |
| 203 | return success(); |
| 204 | } |
| 205 | |
| 206 | //===----------------------------------------------------------------------===// |
| 207 | // CIR ConstArrayAttr |
| 208 | //===----------------------------------------------------------------------===// |
| 209 | |
| 210 | LogicalResult |
| 211 | ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type, |
| 212 | Attribute elts, int trailingZerosNum) { |
| 213 | |
| 214 | if (!(mlir::isa<ArrayAttr, StringAttr>(elts))) |
| 215 | return emitError() << "constant array expects ArrayAttr or StringAttr" ; |
| 216 | |
| 217 | if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) { |
| 218 | const auto arrayTy = mlir::cast<ArrayType>(type); |
| 219 | const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType()); |
| 220 | |
| 221 | // TODO: add CIR type for char. |
| 222 | if (!intTy || intTy.getWidth() != 8) |
| 223 | return emitError() |
| 224 | << "constant array element for string literals expects " |
| 225 | "!cir.int<u, 8> element type" ; |
| 226 | return success(); |
| 227 | } |
| 228 | |
| 229 | assert(mlir::isa<ArrayAttr>(elts)); |
| 230 | const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts); |
| 231 | const auto arrayTy = mlir::cast<ArrayType>(type); |
| 232 | |
| 233 | // Make sure both number of elements and subelement types match type. |
| 234 | if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum) |
| 235 | return emitError() << "constant array size should match type size" ; |
| 236 | return success(); |
| 237 | } |
| 238 | |
| 239 | Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) { |
| 240 | mlir::FailureOr<Type> resultTy; |
| 241 | mlir::FailureOr<Attribute> resultVal; |
| 242 | |
| 243 | // Parse literal '<' |
| 244 | if (parser.parseLess()) |
| 245 | return {}; |
| 246 | |
| 247 | // Parse variable 'value' |
| 248 | resultVal = FieldParser<Attribute>::parse(parser); |
| 249 | if (failed(resultVal)) { |
| 250 | parser.emitError( |
| 251 | parser.getCurrentLocation(), |
| 252 | "failed to parse ConstArrayAttr parameter 'value' which is " |
| 253 | "to be a `Attribute`" ); |
| 254 | return {}; |
| 255 | } |
| 256 | |
| 257 | // ArrayAttrrs have per-element type, not the type of the array... |
| 258 | if (mlir::isa<ArrayAttr>(*resultVal)) { |
| 259 | // Array has implicit type: infer from const array type. |
| 260 | if (parser.parseOptionalColon().failed()) { |
| 261 | resultTy = type; |
| 262 | } else { // Array has explicit type: parse it. |
| 263 | resultTy = FieldParser<Type>::parse(parser); |
| 264 | if (failed(resultTy)) { |
| 265 | parser.emitError( |
| 266 | parser.getCurrentLocation(), |
| 267 | "failed to parse ConstArrayAttr parameter 'type' which is " |
| 268 | "to be a `::mlir::Type`" ); |
| 269 | return {}; |
| 270 | } |
| 271 | } |
| 272 | } else { |
| 273 | auto ta = mlir::cast<TypedAttr>(*resultVal); |
| 274 | resultTy = ta.getType(); |
| 275 | if (mlir::isa<mlir::NoneType>(*resultTy)) { |
| 276 | parser.emitError(parser.getCurrentLocation(), |
| 277 | "expected type declaration for string literal" ); |
| 278 | return {}; |
| 279 | } |
| 280 | } |
| 281 | |
| 282 | unsigned zeros = 0; |
| 283 | if (parser.parseOptionalComma().succeeded()) { |
| 284 | if (parser.parseOptionalKeyword("trailing_zeros" ).succeeded()) { |
| 285 | unsigned typeSize = |
| 286 | mlir::cast<cir::ArrayType>(resultTy.value()).getSize(); |
| 287 | mlir::Attribute elts = resultVal.value(); |
| 288 | if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts)) |
| 289 | zeros = typeSize - str.size(); |
| 290 | else |
| 291 | zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size(); |
| 292 | } else { |
| 293 | return {}; |
| 294 | } |
| 295 | } |
| 296 | |
| 297 | // Parse literal '>' |
| 298 | if (parser.parseGreater()) |
| 299 | return {}; |
| 300 | |
| 301 | return parser.getChecked<ConstArrayAttr>( |
| 302 | parser.getCurrentLocation(), parser.getContext(), resultTy.value(), |
| 303 | resultVal.value(), zeros); |
| 304 | } |
| 305 | |
| 306 | void ConstArrayAttr::print(AsmPrinter &printer) const { |
| 307 | printer << "<" ; |
| 308 | printer.printStrippedAttrOrType(getElts()); |
| 309 | if (getTrailingZerosNum()) |
| 310 | printer << ", trailing_zeros" ; |
| 311 | printer << ">" ; |
| 312 | } |
| 313 | |
| 314 | //===----------------------------------------------------------------------===// |
| 315 | // CIR ConstVectorAttr |
| 316 | //===----------------------------------------------------------------------===// |
| 317 | |
| 318 | LogicalResult |
| 319 | cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 320 | Type type, ArrayAttr elts) { |
| 321 | |
| 322 | if (!mlir::isa<cir::VectorType>(type)) |
| 323 | return emitError() << "type of cir::ConstVectorAttr is not a " |
| 324 | "cir::VectorType: " |
| 325 | << type; |
| 326 | |
| 327 | const auto vecType = mlir::cast<cir::VectorType>(type); |
| 328 | |
| 329 | if (vecType.getSize() != elts.size()) |
| 330 | return emitError() |
| 331 | << "number of constant elements should match vector size" ; |
| 332 | |
| 333 | // Check if the types of the elements match |
| 334 | LogicalResult elementTypeCheck = success(); |
| 335 | elts.walkImmediateSubElements( |
| 336 | [&](Attribute element) { |
| 337 | if (elementTypeCheck.failed()) { |
| 338 | // An earlier element didn't match |
| 339 | return; |
| 340 | } |
| 341 | auto typedElement = mlir::dyn_cast<TypedAttr>(element); |
| 342 | if (!typedElement || |
| 343 | typedElement.getType() != vecType.getElementType()) { |
| 344 | elementTypeCheck = failure(); |
| 345 | emitError() << "constant type should match vector element type" ; |
| 346 | } |
| 347 | }, |
| 348 | [&](Type) {}); |
| 349 | |
| 350 | return elementTypeCheck; |
| 351 | } |
| 352 | |
| 353 | //===----------------------------------------------------------------------===// |
| 354 | // CIR Dialect |
| 355 | //===----------------------------------------------------------------------===// |
| 356 | |
/// Registers all CIR attribute classes with the dialect.
///
/// The class list is not spelled out by hand. With GET_ATTRDEF_LIST defined,
/// the generated CIROpsAttributes.cpp.inc expands to the comma-separated list
/// of attribute classes, which becomes the template argument list of
/// addAttributes<>.
void CIRDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
      >();
}
| 363 | |