1//===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the attributes in the CIR dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "clang/CIR/Dialect/IR/CIRDialect.h"
14
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/TypeSwitch.h"
17
18static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
19 mlir::Type ty);
20static mlir::ParseResult
21parseFloatLiteral(mlir::AsmParser &parser,
22 mlir::FailureOr<llvm::APFloat> &value,
23 cir::CIRFPTypeInterface fpType);
24
25static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
26 mlir::IntegerAttr &value);
27
28static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);
29
30#define GET_ATTRDEF_CLASSES
31#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
32
33using namespace mlir;
34using namespace cir;
35
36//===----------------------------------------------------------------------===//
37// General CIR parsing / printing
38//===----------------------------------------------------------------------===//
39
40Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,
41 Type type) const {
42 llvm::SMLoc typeLoc = parser.getCurrentLocation();
43 llvm::StringRef mnemonic;
44 Attribute genAttr;
45 OptionalParseResult parseResult =
46 generatedAttributeParser(parser, &mnemonic, type, genAttr);
47 if (parseResult.has_value())
48 return genAttr;
49 parser.emitError(typeLoc, "unknown attribute in CIR dialect");
50 return Attribute();
51}
52
53void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
54 if (failed(generatedAttributePrinter(attr, os)))
55 llvm_unreachable("unexpected CIR type kind");
56}
57
58//===----------------------------------------------------------------------===//
59// ConstPtrAttr definitions
60//===----------------------------------------------------------------------===//
61
62// TODO(CIR): Consider encoding the null value differently and use conditional
63// assembly format instead of custom parsing/printing.
64static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
65
66 if (parser.parseOptionalKeyword(keyword: "null").succeeded()) {
67 value = parser.getBuilder().getI64IntegerAttr(0);
68 return success();
69 }
70
71 return parser.parseAttribute(value);
72}
73
74static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
75 if (!value.getInt())
76 p << "null";
77 else
78 p << value;
79}
80
81//===----------------------------------------------------------------------===//
82// IntAttr definitions
83//===----------------------------------------------------------------------===//
84
85Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
86 mlir::APInt apValue;
87
88 if (!mlir::isa<IntType>(odsType))
89 return {};
90 auto type = mlir::cast<IntType>(odsType);
91
92 // Consume the '<' symbol.
93 if (parser.parseLess())
94 return {};
95
96 // Fetch arbitrary precision integer value.
97 if (type.isSigned()) {
98 int64_t value = 0;
99 if (parser.parseInteger(value)) {
100 parser.emitError(parser.getCurrentLocation(), "expected integer value");
101 } else {
102 apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
103 /*implicitTrunc=*/true);
104 if (apValue.getSExtValue() != value)
105 parser.emitError(parser.getCurrentLocation(),
106 "integer value too large for the given type");
107 }
108 } else {
109 uint64_t value = 0;
110 if (parser.parseInteger(value)) {
111 parser.emitError(parser.getCurrentLocation(), "expected integer value");
112 } else {
113 apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
114 /*implicitTrunc=*/true);
115 if (apValue.getZExtValue() != value)
116 parser.emitError(parser.getCurrentLocation(),
117 "integer value too large for the given type");
118 }
119 }
120
121 // Consume the '>' symbol.
122 if (parser.parseGreater())
123 return {};
124
125 return IntAttr::get(type, apValue);
126}
127
128void IntAttr::print(AsmPrinter &printer) const {
129 auto type = mlir::cast<IntType>(getType());
130 printer << '<';
131 if (type.isSigned())
132 printer << getSInt();
133 else
134 printer << getUInt();
135 printer << '>';
136}
137
138LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
139 Type type, APInt value) {
140 if (!mlir::isa<IntType>(type))
141 return emitError() << "expected 'simple.int' type";
142
143 auto intType = mlir::cast<IntType>(type);
144 if (value.getBitWidth() != intType.getWidth())
145 return emitError() << "type and value bitwidth mismatch: "
146 << intType.getWidth() << " != " << value.getBitWidth();
147
148 return success();
149}
150
151//===----------------------------------------------------------------------===//
152// FPAttr definitions
153//===----------------------------------------------------------------------===//
154
155static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
156 p << value;
157}
158
159static ParseResult parseFloatLiteral(AsmParser &parser,
160 FailureOr<APFloat> &value,
161 CIRFPTypeInterface fpType) {
162
163 APFloat parsedValue(0.0);
164 if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
165 return failure();
166
167 value.emplace(args&: parsedValue);
168 return success();
169}
170
171FPAttr FPAttr::getZero(Type type) {
172 return get(type,
173 APFloat::getZero(
174 mlir::cast<CIRFPTypeInterface>(type).getFloatSemantics()));
175}
176
177LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
178 CIRFPTypeInterface fpType, APFloat value) {
179 if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
180 APFloat::SemanticsToEnum(value.getSemantics()))
181 return emitError() << "floating-point semantics mismatch";
182
183 return success();
184}
185
186//===----------------------------------------------------------------------===//
187// ConstComplexAttr definitions
188//===----------------------------------------------------------------------===//
189
190LogicalResult
191ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
192 cir::ComplexType type, mlir::TypedAttr real,
193 mlir::TypedAttr imag) {
194 mlir::Type elemType = type.getElementType();
195 if (real.getType() != elemType)
196 return emitError()
197 << "type of the real part does not match the complex type";
198
199 if (imag.getType() != elemType)
200 return emitError()
201 << "type of the imaginary part does not match the complex type";
202
203 return success();
204}
205
206//===----------------------------------------------------------------------===//
207// CIR ConstArrayAttr
208//===----------------------------------------------------------------------===//
209
210LogicalResult
211ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,
212 Attribute elts, int trailingZerosNum) {
213
214 if (!(mlir::isa<ArrayAttr, StringAttr>(elts)))
215 return emitError() << "constant array expects ArrayAttr or StringAttr";
216
217 if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) {
218 const auto arrayTy = mlir::cast<ArrayType>(type);
219 const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());
220
221 // TODO: add CIR type for char.
222 if (!intTy || intTy.getWidth() != 8)
223 return emitError()
224 << "constant array element for string literals expects "
225 "!cir.int<u, 8> element type";
226 return success();
227 }
228
229 assert(mlir::isa<ArrayAttr>(elts));
230 const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts);
231 const auto arrayTy = mlir::cast<ArrayType>(type);
232
233 // Make sure both number of elements and subelement types match type.
234 if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)
235 return emitError() << "constant array size should match type size";
236 return success();
237}
238
239Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {
240 mlir::FailureOr<Type> resultTy;
241 mlir::FailureOr<Attribute> resultVal;
242
243 // Parse literal '<'
244 if (parser.parseLess())
245 return {};
246
247 // Parse variable 'value'
248 resultVal = FieldParser<Attribute>::parse(parser);
249 if (failed(resultVal)) {
250 parser.emitError(
251 parser.getCurrentLocation(),
252 "failed to parse ConstArrayAttr parameter 'value' which is "
253 "to be a `Attribute`");
254 return {};
255 }
256
257 // ArrayAttrrs have per-element type, not the type of the array...
258 if (mlir::isa<ArrayAttr>(*resultVal)) {
259 // Array has implicit type: infer from const array type.
260 if (parser.parseOptionalColon().failed()) {
261 resultTy = type;
262 } else { // Array has explicit type: parse it.
263 resultTy = FieldParser<Type>::parse(parser);
264 if (failed(resultTy)) {
265 parser.emitError(
266 parser.getCurrentLocation(),
267 "failed to parse ConstArrayAttr parameter 'type' which is "
268 "to be a `::mlir::Type`");
269 return {};
270 }
271 }
272 } else {
273 auto ta = mlir::cast<TypedAttr>(*resultVal);
274 resultTy = ta.getType();
275 if (mlir::isa<mlir::NoneType>(*resultTy)) {
276 parser.emitError(parser.getCurrentLocation(),
277 "expected type declaration for string literal");
278 return {};
279 }
280 }
281
282 unsigned zeros = 0;
283 if (parser.parseOptionalComma().succeeded()) {
284 if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {
285 unsigned typeSize =
286 mlir::cast<cir::ArrayType>(resultTy.value()).getSize();
287 mlir::Attribute elts = resultVal.value();
288 if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts))
289 zeros = typeSize - str.size();
290 else
291 zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size();
292 } else {
293 return {};
294 }
295 }
296
297 // Parse literal '>'
298 if (parser.parseGreater())
299 return {};
300
301 return parser.getChecked<ConstArrayAttr>(
302 parser.getCurrentLocation(), parser.getContext(), resultTy.value(),
303 resultVal.value(), zeros);
304}
305
306void ConstArrayAttr::print(AsmPrinter &printer) const {
307 printer << "<";
308 printer.printStrippedAttrOrType(getElts());
309 if (getTrailingZerosNum())
310 printer << ", trailing_zeros";
311 printer << ">";
312}
313
314//===----------------------------------------------------------------------===//
315// CIR ConstVectorAttr
316//===----------------------------------------------------------------------===//
317
318LogicalResult
319cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
320 Type type, ArrayAttr elts) {
321
322 if (!mlir::isa<cir::VectorType>(type))
323 return emitError() << "type of cir::ConstVectorAttr is not a "
324 "cir::VectorType: "
325 << type;
326
327 const auto vecType = mlir::cast<cir::VectorType>(type);
328
329 if (vecType.getSize() != elts.size())
330 return emitError()
331 << "number of constant elements should match vector size";
332
333 // Check if the types of the elements match
334 LogicalResult elementTypeCheck = success();
335 elts.walkImmediateSubElements(
336 [&](Attribute element) {
337 if (elementTypeCheck.failed()) {
338 // An earlier element didn't match
339 return;
340 }
341 auto typedElement = mlir::dyn_cast<TypedAttr>(element);
342 if (!typedElement ||
343 typedElement.getType() != vecType.getElementType()) {
344 elementTypeCheck = failure();
345 emitError() << "constant type should match vector element type";
346 }
347 },
348 [&](Type) {});
349
350 return elementTypeCheck;
351}
352
353//===----------------------------------------------------------------------===//
354// CIR Dialect
355//===----------------------------------------------------------------------===//
356
357void CIRDialect::registerAttributes() {
358 addAttributes<
359#define GET_ATTRDEF_LIST
360#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
361 >();
362}
363

source code of clang/lib/CIR/Dialect/IR/CIRAttrs.cpp