| 1 | //===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the attributes in the CIR dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| 14 | |
| 15 | #include "mlir/IR/DialectImplementation.h" |
| 16 | #include "llvm/ADT/TypeSwitch.h" |
| 17 | |
//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//

// Custom printer/parser pair for integer literals of a CIR integer type. These
// are declared ahead of the generated attribute definitions that are included
// below; presumably the TableGen-generated code references them — TODO confirm.
static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
                            cir::IntTypeInterface ty);
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
                                         llvm::APInt &value,
                                         cir::IntTypeInterface ty);

//===-----------------------------------------------------------------===//
// FloatLiteral
//===-----------------------------------------------------------------===//

// Custom printer/parser pair for floating-point literals. The parser receives
// the target FP type so the literal can be parsed with that type's semantics.
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
                              mlir::Type ty);
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
                  mlir::FailureOr<llvm::APFloat> &value,
                  cir::FPTypeInterface fpType);

//===-----------------------------------------------------------------===//
// ConstPtr
//===-----------------------------------------------------------------===//

// Custom parser/printer pair for constant pointer values, which accept and
// emit the `null` keyword as a spelling of the integer value 0.
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
                                       mlir::IntegerAttr &value);

static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
| 42 | |
| 43 | #define GET_ATTRDEF_CLASSES |
| 44 | #include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc" |
| 45 | |
| 46 | using namespace mlir; |
| 47 | using namespace cir; |
| 48 | |
| 49 | //===----------------------------------------------------------------------===// |
| 50 | // General CIR parsing / printing |
| 51 | //===----------------------------------------------------------------------===// |
| 52 | |
| 53 | Attribute CIRDialect::parseAttribute(DialectAsmParser &parser, |
| 54 | Type type) const { |
| 55 | llvm::SMLoc typeLoc = parser.getCurrentLocation(); |
| 56 | llvm::StringRef mnemonic; |
| 57 | Attribute genAttr; |
| 58 | OptionalParseResult parseResult = |
| 59 | generatedAttributeParser(parser, &mnemonic, type, genAttr); |
| 60 | if (parseResult.has_value()) |
| 61 | return genAttr; |
| 62 | parser.emitError(typeLoc, "unknown attribute in CIR dialect" ); |
| 63 | return Attribute(); |
| 64 | } |
| 65 | |
| 66 | void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { |
| 67 | if (failed(generatedAttributePrinter(attr, os))) |
| 68 | llvm_unreachable("unexpected CIR type kind" ); |
| 69 | } |
| 70 | |
| 71 | //===----------------------------------------------------------------------===// |
| 72 | // OptInfoAttr definitions |
| 73 | //===----------------------------------------------------------------------===// |
| 74 | |
| 75 | LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 76 | unsigned level, unsigned size) { |
| 77 | if (level > 3) |
| 78 | return emitError() |
| 79 | << "optimization level must be between 0 and 3 inclusive" ; |
| 80 | if (size > 2) |
| 81 | return emitError() |
| 82 | << "size optimization level must be between 0 and 2 inclusive" ; |
| 83 | return success(); |
| 84 | } |
| 85 | |
| 86 | //===----------------------------------------------------------------------===// |
| 87 | // ConstPtrAttr definitions |
| 88 | //===----------------------------------------------------------------------===// |
| 89 | |
| 90 | // TODO(CIR): Consider encoding the null value differently and use conditional |
| 91 | // assembly format instead of custom parsing/printing. |
| 92 | static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) { |
| 93 | |
| 94 | if (parser.parseOptionalKeyword(keyword: "null" ).succeeded()) { |
| 95 | value = parser.getBuilder().getI64IntegerAttr(0); |
| 96 | return success(); |
| 97 | } |
| 98 | |
| 99 | return parser.parseAttribute(result&: value); |
| 100 | } |
| 101 | |
| 102 | static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) { |
| 103 | if (!value.getInt()) |
| 104 | p << "null" ; |
| 105 | else |
| 106 | p << value; |
| 107 | } |
| 108 | |
| 109 | //===----------------------------------------------------------------------===// |
| 110 | // IntAttr definitions |
| 111 | //===----------------------------------------------------------------------===// |
| 112 | |
| 113 | template <typename IntT> |
| 114 | static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) { |
| 115 | if constexpr (std::is_signed_v<IntT>) { |
| 116 | return value.getSExtValue() != expectedValue; |
| 117 | } else { |
| 118 | return value.getZExtValue() != expectedValue; |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | template <typename IntT> |
| 123 | static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p, |
| 124 | llvm::APInt &value, |
| 125 | cir::IntTypeInterface ty) { |
| 126 | IntT ivalue; |
| 127 | const bool isSigned = ty.isSigned(); |
| 128 | if (p.parseInteger(ivalue)) |
| 129 | return p.emitError(loc: p.getCurrentLocation(), message: "expected integer value" ); |
| 130 | |
| 131 | value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true); |
| 132 | if (isTooLargeForType(value, ivalue)) |
| 133 | return p.emitError(loc: p.getCurrentLocation(), |
| 134 | message: "integer value too large for the given type" ); |
| 135 | |
| 136 | return success(); |
| 137 | } |
| 138 | |
| 139 | mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value, |
| 140 | cir::IntTypeInterface ty) { |
| 141 | if (ty.isSigned()) |
| 142 | return parseIntLiteralImpl<int64_t>(parser, value, ty); |
| 143 | return parseIntLiteralImpl<uint64_t>(parser, value, ty); |
| 144 | } |
| 145 | |
| 146 | void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value, |
| 147 | cir::IntTypeInterface ty) { |
| 148 | if (ty.isSigned()) |
| 149 | p << value.getSExtValue(); |
| 150 | else |
| 151 | p << value.getZExtValue(); |
| 152 | } |
| 153 | |
| 154 | LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 155 | cir::IntTypeInterface type, llvm::APInt value) { |
| 156 | if (value.getBitWidth() != type.getWidth()) |
| 157 | return emitError() << "type and value bitwidth mismatch: " |
| 158 | << type.getWidth() << " != " << value.getBitWidth(); |
| 159 | return success(); |
| 160 | } |
| 161 | |
| 162 | //===----------------------------------------------------------------------===// |
| 163 | // FPAttr definitions |
| 164 | //===----------------------------------------------------------------------===// |
| 165 | |
| 166 | static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) { |
| 167 | p << value; |
| 168 | } |
| 169 | |
| 170 | static ParseResult parseFloatLiteral(AsmParser &parser, |
| 171 | FailureOr<APFloat> &value, |
| 172 | cir::FPTypeInterface fpType) { |
| 173 | |
| 174 | APFloat parsedValue(0.0); |
| 175 | if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue)) |
| 176 | return failure(); |
| 177 | |
| 178 | value.emplace(args&: parsedValue); |
| 179 | return success(); |
| 180 | } |
| 181 | |
| 182 | FPAttr FPAttr::getZero(Type type) { |
| 183 | return get(type, |
| 184 | APFloat::getZero( |
| 185 | mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics())); |
| 186 | } |
| 187 | |
| 188 | LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 189 | cir::FPTypeInterface fpType, APFloat value) { |
| 190 | if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) != |
| 191 | APFloat::SemanticsToEnum(value.getSemantics())) |
| 192 | return emitError() << "floating-point semantics mismatch" ; |
| 193 | |
| 194 | return success(); |
| 195 | } |
| 196 | |
| 197 | //===----------------------------------------------------------------------===// |
| 198 | // ConstComplexAttr definitions |
| 199 | //===----------------------------------------------------------------------===// |
| 200 | |
| 201 | LogicalResult |
| 202 | ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 203 | cir::ComplexType type, mlir::TypedAttr real, |
| 204 | mlir::TypedAttr imag) { |
| 205 | mlir::Type elemType = type.getElementType(); |
| 206 | if (real.getType() != elemType) |
| 207 | return emitError() |
| 208 | << "type of the real part does not match the complex type" ; |
| 209 | |
| 210 | if (imag.getType() != elemType) |
| 211 | return emitError() |
| 212 | << "type of the imaginary part does not match the complex type" ; |
| 213 | |
| 214 | return success(); |
| 215 | } |
| 216 | |
| 217 | //===----------------------------------------------------------------------===// |
| 218 | // CIR ConstArrayAttr |
| 219 | //===----------------------------------------------------------------------===// |
| 220 | |
| 221 | LogicalResult |
| 222 | ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type, |
| 223 | Attribute elts, int trailingZerosNum) { |
| 224 | |
| 225 | if (!(mlir::isa<ArrayAttr, StringAttr>(Val: elts))) |
| 226 | return emitError() << "constant array expects ArrayAttr or StringAttr" ; |
| 227 | |
| 228 | if (auto strAttr = mlir::dyn_cast<StringAttr>(Val&: elts)) { |
| 229 | const auto arrayTy = mlir::cast<ArrayType>(type); |
| 230 | const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType()); |
| 231 | |
| 232 | // TODO: add CIR type for char. |
| 233 | if (!intTy || intTy.getWidth() != 8) |
| 234 | return emitError() |
| 235 | << "constant array element for string literals expects " |
| 236 | "!cir.int<u, 8> element type" ; |
| 237 | return success(); |
| 238 | } |
| 239 | |
| 240 | assert(mlir::isa<ArrayAttr>(elts)); |
| 241 | const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(Val&: elts); |
| 242 | const auto arrayTy = mlir::cast<ArrayType>(type); |
| 243 | |
| 244 | // Make sure both number of elements and subelement types match type. |
| 245 | if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum) |
| 246 | return emitError() << "constant array size should match type size" ; |
| 247 | return success(); |
| 248 | } |
| 249 | |
| 250 | Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) { |
| 251 | mlir::FailureOr<Type> resultTy; |
| 252 | mlir::FailureOr<Attribute> resultVal; |
| 253 | |
| 254 | // Parse literal '<' |
| 255 | if (parser.parseLess()) |
| 256 | return {}; |
| 257 | |
| 258 | // Parse variable 'value' |
| 259 | resultVal = FieldParser<Attribute>::parse(parser); |
| 260 | if (failed(Result: resultVal)) { |
| 261 | parser.emitError( |
| 262 | loc: parser.getCurrentLocation(), |
| 263 | message: "failed to parse ConstArrayAttr parameter 'value' which is " |
| 264 | "to be a `Attribute`" ); |
| 265 | return {}; |
| 266 | } |
| 267 | |
| 268 | // ArrayAttrrs have per-element type, not the type of the array... |
| 269 | if (mlir::isa<ArrayAttr>(Val: *resultVal)) { |
| 270 | // Array has implicit type: infer from const array type. |
| 271 | if (parser.parseOptionalColon().failed()) { |
| 272 | resultTy = type; |
| 273 | } else { // Array has explicit type: parse it. |
| 274 | resultTy = FieldParser<Type>::parse(parser); |
| 275 | if (failed(Result: resultTy)) { |
| 276 | parser.emitError( |
| 277 | loc: parser.getCurrentLocation(), |
| 278 | message: "failed to parse ConstArrayAttr parameter 'type' which is " |
| 279 | "to be a `::mlir::Type`" ); |
| 280 | return {}; |
| 281 | } |
| 282 | } |
| 283 | } else { |
| 284 | auto ta = mlir::cast<TypedAttr>(Val&: *resultVal); |
| 285 | resultTy = ta.getType(); |
| 286 | if (mlir::isa<mlir::NoneType>(Val: *resultTy)) { |
| 287 | parser.emitError(loc: parser.getCurrentLocation(), |
| 288 | message: "expected type declaration for string literal" ); |
| 289 | return {}; |
| 290 | } |
| 291 | } |
| 292 | |
| 293 | unsigned zeros = 0; |
| 294 | if (parser.parseOptionalComma().succeeded()) { |
| 295 | if (parser.parseOptionalKeyword(keyword: "trailing_zeros" ).succeeded()) { |
| 296 | unsigned typeSize = |
| 297 | mlir::cast<cir::ArrayType>(resultTy.value()).getSize(); |
| 298 | mlir::Attribute elts = resultVal.value(); |
| 299 | if (auto str = mlir::dyn_cast<mlir::StringAttr>(Val&: elts)) |
| 300 | zeros = typeSize - str.size(); |
| 301 | else |
| 302 | zeros = typeSize - mlir::cast<mlir::ArrayAttr>(Val&: elts).size(); |
| 303 | } else { |
| 304 | return {}; |
| 305 | } |
| 306 | } |
| 307 | |
| 308 | // Parse literal '>' |
| 309 | if (parser.parseGreater()) |
| 310 | return {}; |
| 311 | |
| 312 | return parser.getChecked<ConstArrayAttr>( |
| 313 | loc: parser.getCurrentLocation(), params: parser.getContext(), params&: resultTy.value(), |
| 314 | params&: resultVal.value(), params&: zeros); |
| 315 | } |
| 316 | |
| 317 | void ConstArrayAttr::print(AsmPrinter &printer) const { |
| 318 | printer << "<" ; |
| 319 | printer.printStrippedAttrOrType(getElts()); |
| 320 | if (getTrailingZerosNum()) |
| 321 | printer << ", trailing_zeros" ; |
| 322 | printer << ">" ; |
| 323 | } |
| 324 | |
| 325 | //===----------------------------------------------------------------------===// |
| 326 | // CIR ConstVectorAttr |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | |
| 329 | LogicalResult |
| 330 | cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| 331 | Type type, ArrayAttr elts) { |
| 332 | |
| 333 | if (!mlir::isa<cir::VectorType>(type)) |
| 334 | return emitError() << "type of cir::ConstVectorAttr is not a " |
| 335 | "cir::VectorType: " |
| 336 | << type; |
| 337 | |
| 338 | const auto vecType = mlir::cast<cir::VectorType>(type); |
| 339 | |
| 340 | if (vecType.getSize() != elts.size()) |
| 341 | return emitError() |
| 342 | << "number of constant elements should match vector size" ; |
| 343 | |
| 344 | // Check if the types of the elements match |
| 345 | LogicalResult elementTypeCheck = success(); |
| 346 | elts.walkImmediateSubElements( |
| 347 | [&](Attribute element) { |
| 348 | if (elementTypeCheck.failed()) { |
| 349 | // An earlier element didn't match |
| 350 | return; |
| 351 | } |
| 352 | auto typedElement = mlir::dyn_cast<TypedAttr>(element); |
| 353 | if (!typedElement || |
| 354 | typedElement.getType() != vecType.getElementType()) { |
| 355 | elementTypeCheck = failure(); |
| 356 | emitError() << "constant type should match vector element type" ; |
| 357 | } |
| 358 | }, |
| 359 | [&](Type) {}); |
| 360 | |
| 361 | return elementTypeCheck; |
| 362 | } |
| 363 | |
| 364 | //===----------------------------------------------------------------------===// |
| 365 | // CIR Dialect |
| 366 | //===----------------------------------------------------------------------===// |
| 367 | |
/// Registers every CIR attribute class with the dialect. The list of classes
/// is produced by TableGen and pulled in via GET_ATTRDEF_LIST from
/// CIROpsAttributes.cpp.inc, so new attributes need no change here.
void CIRDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
      >();
}
| 374 | |