1//===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the attributes in the CIR dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "clang/CIR/Dialect/IR/CIRDialect.h"
14
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/TypeSwitch.h"
17
18//===-----------------------------------------------------------------===//
19// IntLiteral
20//===-----------------------------------------------------------------===//
21
22static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
23 cir::IntTypeInterface ty);
24static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
25 llvm::APInt &value,
26 cir::IntTypeInterface ty);
27//===-----------------------------------------------------------------===//
28// FloatLiteral
29//===-----------------------------------------------------------------===//
30
31static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
32 mlir::Type ty);
33static mlir::ParseResult
34parseFloatLiteral(mlir::AsmParser &parser,
35 mlir::FailureOr<llvm::APFloat> &value,
36 cir::FPTypeInterface fpType);
37
38static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
39 mlir::IntegerAttr &value);
40
41static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
42
43#define GET_ATTRDEF_CLASSES
44#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
45
46using namespace mlir;
47using namespace cir;
48
49//===----------------------------------------------------------------------===//
50// General CIR parsing / printing
51//===----------------------------------------------------------------------===//
52
53Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,
54 Type type) const {
55 llvm::SMLoc typeLoc = parser.getCurrentLocation();
56 llvm::StringRef mnemonic;
57 Attribute genAttr;
58 OptionalParseResult parseResult =
59 generatedAttributeParser(parser, &mnemonic, type, genAttr);
60 if (parseResult.has_value())
61 return genAttr;
62 parser.emitError(typeLoc, "unknown attribute in CIR dialect");
63 return Attribute();
64}
65
66void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
67 if (failed(generatedAttributePrinter(attr, os)))
68 llvm_unreachable("unexpected CIR type kind");
69}
70
71//===----------------------------------------------------------------------===//
72// OptInfoAttr definitions
73//===----------------------------------------------------------------------===//
74
75LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
76 unsigned level, unsigned size) {
77 if (level > 3)
78 return emitError()
79 << "optimization level must be between 0 and 3 inclusive";
80 if (size > 2)
81 return emitError()
82 << "size optimization level must be between 0 and 2 inclusive";
83 return success();
84}
85
86//===----------------------------------------------------------------------===//
87// ConstPtrAttr definitions
88//===----------------------------------------------------------------------===//
89
90// TODO(CIR): Consider encoding the null value differently and use conditional
91// assembly format instead of custom parsing/printing.
92static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
93
94 if (parser.parseOptionalKeyword(keyword: "null").succeeded()) {
95 value = parser.getBuilder().getI64IntegerAttr(0);
96 return success();
97 }
98
99 return parser.parseAttribute(result&: value);
100}
101
102static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
103 if (!value.getInt())
104 p << "null";
105 else
106 p << value;
107}
108
109//===----------------------------------------------------------------------===//
110// IntAttr definitions
111//===----------------------------------------------------------------------===//
112
113template <typename IntT>
114static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
115 if constexpr (std::is_signed_v<IntT>) {
116 return value.getSExtValue() != expectedValue;
117 } else {
118 return value.getZExtValue() != expectedValue;
119 }
120}
121
122template <typename IntT>
123static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
124 llvm::APInt &value,
125 cir::IntTypeInterface ty) {
126 IntT ivalue;
127 const bool isSigned = ty.isSigned();
128 if (p.parseInteger(ivalue))
129 return p.emitError(loc: p.getCurrentLocation(), message: "expected integer value");
130
131 value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
132 if (isTooLargeForType(value, ivalue))
133 return p.emitError(loc: p.getCurrentLocation(),
134 message: "integer value too large for the given type");
135
136 return success();
137}
138
139mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
140 cir::IntTypeInterface ty) {
141 if (ty.isSigned())
142 return parseIntLiteralImpl<int64_t>(parser, value, ty);
143 return parseIntLiteralImpl<uint64_t>(parser, value, ty);
144}
145
146void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
147 cir::IntTypeInterface ty) {
148 if (ty.isSigned())
149 p << value.getSExtValue();
150 else
151 p << value.getZExtValue();
152}
153
154LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
155 cir::IntTypeInterface type, llvm::APInt value) {
156 if (value.getBitWidth() != type.getWidth())
157 return emitError() << "type and value bitwidth mismatch: "
158 << type.getWidth() << " != " << value.getBitWidth();
159 return success();
160}
161
162//===----------------------------------------------------------------------===//
163// FPAttr definitions
164//===----------------------------------------------------------------------===//
165
166static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
167 p << value;
168}
169
170static ParseResult parseFloatLiteral(AsmParser &parser,
171 FailureOr<APFloat> &value,
172 cir::FPTypeInterface fpType) {
173
174 APFloat parsedValue(0.0);
175 if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
176 return failure();
177
178 value.emplace(args&: parsedValue);
179 return success();
180}
181
182FPAttr FPAttr::getZero(Type type) {
183 return get(type,
184 APFloat::getZero(
185 mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));
186}
187
188LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
189 cir::FPTypeInterface fpType, APFloat value) {
190 if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
191 APFloat::SemanticsToEnum(value.getSemantics()))
192 return emitError() << "floating-point semantics mismatch";
193
194 return success();
195}
196
197//===----------------------------------------------------------------------===//
198// ConstComplexAttr definitions
199//===----------------------------------------------------------------------===//
200
201LogicalResult
202ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
203 cir::ComplexType type, mlir::TypedAttr real,
204 mlir::TypedAttr imag) {
205 mlir::Type elemType = type.getElementType();
206 if (real.getType() != elemType)
207 return emitError()
208 << "type of the real part does not match the complex type";
209
210 if (imag.getType() != elemType)
211 return emitError()
212 << "type of the imaginary part does not match the complex type";
213
214 return success();
215}
216
217//===----------------------------------------------------------------------===//
218// CIR ConstArrayAttr
219//===----------------------------------------------------------------------===//
220
221LogicalResult
222ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
223 Attribute elts, int trailingZerosNum) {
224
225 if (!(mlir::isa<ArrayAttr, StringAttr>(Val: elts)))
226 return emitError() << "constant array expects ArrayAttr or StringAttr";
227
228 if (auto strAttr = mlir::dyn_cast<StringAttr>(Val&: elts)) {
229 const auto arrayTy = mlir::cast<ArrayType>(type);
230 const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());
231
232 // TODO: add CIR type for char.
233 if (!intTy || intTy.getWidth() != 8)
234 return emitError()
235 << "constant array element for string literals expects "
236 "!cir.int<u, 8> element type";
237 return success();
238 }
239
240 assert(mlir::isa<ArrayAttr>(elts));
241 const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(Val&: elts);
242 const auto arrayTy = mlir::cast<ArrayType>(type);
243
244 // Make sure both number of elements and subelement types match type.
245 if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)
246 return emitError() << "constant array size should match type size";
247 return success();
248}
249
250Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {
251 mlir::FailureOr<Type> resultTy;
252 mlir::FailureOr<Attribute> resultVal;
253
254 // Parse literal '<'
255 if (parser.parseLess())
256 return {};
257
258 // Parse variable 'value'
259 resultVal = FieldParser<Attribute>::parse(parser);
260 if (failed(Result: resultVal)) {
261 parser.emitError(
262 loc: parser.getCurrentLocation(),
263 message: "failed to parse ConstArrayAttr parameter 'value' which is "
264 "to be a `Attribute`");
265 return {};
266 }
267
268 // ArrayAttrrs have per-element type, not the type of the array...
269 if (mlir::isa<ArrayAttr>(Val: *resultVal)) {
270 // Array has implicit type: infer from const array type.
271 if (parser.parseOptionalColon().failed()) {
272 resultTy = type;
273 } else { // Array has explicit type: parse it.
274 resultTy = FieldParser<Type>::parse(parser);
275 if (failed(Result: resultTy)) {
276 parser.emitError(
277 loc: parser.getCurrentLocation(),
278 message: "failed to parse ConstArrayAttr parameter 'type' which is "
279 "to be a `::mlir::Type`");
280 return {};
281 }
282 }
283 } else {
284 auto ta = mlir::cast<TypedAttr>(Val&: *resultVal);
285 resultTy = ta.getType();
286 if (mlir::isa<mlir::NoneType>(Val: *resultTy)) {
287 parser.emitError(loc: parser.getCurrentLocation(),
288 message: "expected type declaration for string literal");
289 return {};
290 }
291 }
292
293 unsigned zeros = 0;
294 if (parser.parseOptionalComma().succeeded()) {
295 if (parser.parseOptionalKeyword(keyword: "trailing_zeros").succeeded()) {
296 unsigned typeSize =
297 mlir::cast<cir::ArrayType>(resultTy.value()).getSize();
298 mlir::Attribute elts = resultVal.value();
299 if (auto str = mlir::dyn_cast<mlir::StringAttr>(Val&: elts))
300 zeros = typeSize - str.size();
301 else
302 zeros = typeSize - mlir::cast<mlir::ArrayAttr>(Val&: elts).size();
303 } else {
304 return {};
305 }
306 }
307
308 // Parse literal '>'
309 if (parser.parseGreater())
310 return {};
311
312 return parser.getChecked<ConstArrayAttr>(
313 loc: parser.getCurrentLocation(), params: parser.getContext(), params&: resultTy.value(),
314 params&: resultVal.value(), params&: zeros);
315}
316
317void ConstArrayAttr::print(AsmPrinter &printer) const {
318 printer << "<";
319 printer.printStrippedAttrOrType(getElts());
320 if (getTrailingZerosNum())
321 printer << ", trailing_zeros";
322 printer << ">";
323}
324
325//===----------------------------------------------------------------------===//
326// CIR ConstVectorAttr
327//===----------------------------------------------------------------------===//
328
329LogicalResult
330cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
331 Type type, ArrayAttr elts) {
332
333 if (!mlir::isa<cir::VectorType>(type))
334 return emitError() << "type of cir::ConstVectorAttr is not a "
335 "cir::VectorType: "
336 << type;
337
338 const auto vecType = mlir::cast<cir::VectorType>(type);
339
340 if (vecType.getSize() != elts.size())
341 return emitError()
342 << "number of constant elements should match vector size";
343
344 // Check if the types of the elements match
345 LogicalResult elementTypeCheck = success();
346 elts.walkImmediateSubElements(
347 [&](Attribute element) {
348 if (elementTypeCheck.failed()) {
349 // An earlier element didn't match
350 return;
351 }
352 auto typedElement = mlir::dyn_cast<TypedAttr>(element);
353 if (!typedElement ||
354 typedElement.getType() != vecType.getElementType()) {
355 elementTypeCheck = failure();
356 emitError() << "constant type should match vector element type";
357 }
358 },
359 [&](Type) {});
360
361 return elementTypeCheck;
362}
363
364//===----------------------------------------------------------------------===//
365// CIR Dialect
366//===----------------------------------------------------------------------===//
367
/// Register all CIR attribute classes with the dialect. The class list is
/// produced by TableGen: including CIROpsAttributes.cpp.inc with
/// GET_ATTRDEF_LIST defined expands to the comma-separated attribute types
/// used as the template arguments of addAttributes.
void CIRDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
      >();
}
374

source code of clang/lib/CIR/Dialect/IR/CIRAttrs.cpp