1 | //===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the attributes in the CIR dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
14 | |
15 | #include "mlir/IR/DialectImplementation.h" |
16 | #include "llvm/ADT/TypeSwitch.h" |
17 | |
18 | static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value, |
19 | mlir::Type ty); |
20 | static mlir::ParseResult |
21 | parseFloatLiteral(mlir::AsmParser &parser, |
22 | mlir::FailureOr<llvm::APFloat> &value, |
23 | cir::CIRFPTypeInterface fpType); |
24 | |
25 | static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser, |
26 | mlir::IntegerAttr &value); |
27 | |
28 | static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value); |
29 | |
30 | #define GET_ATTRDEF_CLASSES |
31 | #include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc" |
32 | |
33 | using namespace mlir; |
34 | using namespace cir; |
35 | |
36 | //===----------------------------------------------------------------------===// |
37 | // General CIR parsing / printing |
38 | //===----------------------------------------------------------------------===// |
39 | |
40 | Attribute CIRDialect::parseAttribute(DialectAsmParser &parser, |
41 | Type type) const { |
42 | llvm::SMLoc typeLoc = parser.getCurrentLocation(); |
43 | llvm::StringRef mnemonic; |
44 | Attribute genAttr; |
45 | OptionalParseResult parseResult = |
46 | generatedAttributeParser(parser, &mnemonic, type, genAttr); |
47 | if (parseResult.has_value()) |
48 | return genAttr; |
49 | parser.emitError(typeLoc, "unknown attribute in CIR dialect" ); |
50 | return Attribute(); |
51 | } |
52 | |
53 | void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { |
54 | if (failed(generatedAttributePrinter(attr, os))) |
55 | llvm_unreachable("unexpected CIR type kind" ); |
56 | } |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // ConstPtrAttr definitions |
60 | //===----------------------------------------------------------------------===// |
61 | |
62 | // TODO(CIR): Consider encoding the null value differently and use conditional |
63 | // assembly format instead of custom parsing/printing. |
64 | static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) { |
65 | |
66 | if (parser.parseOptionalKeyword(keyword: "null" ).succeeded()) { |
67 | value = parser.getBuilder().getI64IntegerAttr(0); |
68 | return success(); |
69 | } |
70 | |
71 | return parser.parseAttribute(value); |
72 | } |
73 | |
74 | static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) { |
75 | if (!value.getInt()) |
76 | p << "null" ; |
77 | else |
78 | p << value; |
79 | } |
80 | |
81 | //===----------------------------------------------------------------------===// |
82 | // IntAttr definitions |
83 | //===----------------------------------------------------------------------===// |
84 | |
85 | Attribute IntAttr::parse(AsmParser &parser, Type odsType) { |
86 | mlir::APInt apValue; |
87 | |
88 | if (!mlir::isa<IntType>(odsType)) |
89 | return {}; |
90 | auto type = mlir::cast<IntType>(odsType); |
91 | |
92 | // Consume the '<' symbol. |
93 | if (parser.parseLess()) |
94 | return {}; |
95 | |
96 | // Fetch arbitrary precision integer value. |
97 | if (type.isSigned()) { |
98 | int64_t value = 0; |
99 | if (parser.parseInteger(value)) { |
100 | parser.emitError(parser.getCurrentLocation(), "expected integer value" ); |
101 | } else { |
102 | apValue = mlir::APInt(type.getWidth(), value, type.isSigned(), |
103 | /*implicitTrunc=*/true); |
104 | if (apValue.getSExtValue() != value) |
105 | parser.emitError(parser.getCurrentLocation(), |
106 | "integer value too large for the given type" ); |
107 | } |
108 | } else { |
109 | uint64_t value = 0; |
110 | if (parser.parseInteger(value)) { |
111 | parser.emitError(parser.getCurrentLocation(), "expected integer value" ); |
112 | } else { |
113 | apValue = mlir::APInt(type.getWidth(), value, type.isSigned(), |
114 | /*implicitTrunc=*/true); |
115 | if (apValue.getZExtValue() != value) |
116 | parser.emitError(parser.getCurrentLocation(), |
117 | "integer value too large for the given type" ); |
118 | } |
119 | } |
120 | |
121 | // Consume the '>' symbol. |
122 | if (parser.parseGreater()) |
123 | return {}; |
124 | |
125 | return IntAttr::get(type, apValue); |
126 | } |
127 | |
128 | void IntAttr::print(AsmPrinter &printer) const { |
129 | auto type = mlir::cast<IntType>(getType()); |
130 | printer << '<'; |
131 | if (type.isSigned()) |
132 | printer << getSInt(); |
133 | else |
134 | printer << getUInt(); |
135 | printer << '>'; |
136 | } |
137 | |
138 | LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
139 | Type type, APInt value) { |
140 | if (!mlir::isa<IntType>(type)) |
141 | return emitError() << "expected 'simple.int' type" ; |
142 | |
143 | auto intType = mlir::cast<IntType>(type); |
144 | if (value.getBitWidth() != intType.getWidth()) |
145 | return emitError() << "type and value bitwidth mismatch: " |
146 | << intType.getWidth() << " != " << value.getBitWidth(); |
147 | |
148 | return success(); |
149 | } |
150 | |
151 | //===----------------------------------------------------------------------===// |
152 | // FPAttr definitions |
153 | //===----------------------------------------------------------------------===// |
154 | |
155 | static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) { |
156 | p << value; |
157 | } |
158 | |
159 | static ParseResult parseFloatLiteral(AsmParser &parser, |
160 | FailureOr<APFloat> &value, |
161 | CIRFPTypeInterface fpType) { |
162 | |
163 | APFloat parsedValue(0.0); |
164 | if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue)) |
165 | return failure(); |
166 | |
167 | value.emplace(args&: parsedValue); |
168 | return success(); |
169 | } |
170 | |
171 | FPAttr FPAttr::getZero(Type type) { |
172 | return get(type, |
173 | APFloat::getZero( |
174 | mlir::cast<CIRFPTypeInterface>(type).getFloatSemantics())); |
175 | } |
176 | |
177 | LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
178 | CIRFPTypeInterface fpType, APFloat value) { |
179 | if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) != |
180 | APFloat::SemanticsToEnum(value.getSemantics())) |
181 | return emitError() << "floating-point semantics mismatch" ; |
182 | |
183 | return success(); |
184 | } |
185 | |
186 | //===----------------------------------------------------------------------===// |
187 | // ConstComplexAttr definitions |
188 | //===----------------------------------------------------------------------===// |
189 | |
190 | LogicalResult |
191 | ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
192 | cir::ComplexType type, mlir::TypedAttr real, |
193 | mlir::TypedAttr imag) { |
194 | mlir::Type elemType = type.getElementType(); |
195 | if (real.getType() != elemType) |
196 | return emitError() |
197 | << "type of the real part does not match the complex type" ; |
198 | |
199 | if (imag.getType() != elemType) |
200 | return emitError() |
201 | << "type of the imaginary part does not match the complex type" ; |
202 | |
203 | return success(); |
204 | } |
205 | |
206 | //===----------------------------------------------------------------------===// |
207 | // CIR ConstArrayAttr |
208 | //===----------------------------------------------------------------------===// |
209 | |
210 | LogicalResult |
211 | ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type, |
212 | Attribute elts, int trailingZerosNum) { |
213 | |
214 | if (!(mlir::isa<ArrayAttr, StringAttr>(elts))) |
215 | return emitError() << "constant array expects ArrayAttr or StringAttr" ; |
216 | |
217 | if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) { |
218 | const auto arrayTy = mlir::cast<ArrayType>(type); |
219 | const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType()); |
220 | |
221 | // TODO: add CIR type for char. |
222 | if (!intTy || intTy.getWidth() != 8) |
223 | return emitError() |
224 | << "constant array element for string literals expects " |
225 | "!cir.int<u, 8> element type" ; |
226 | return success(); |
227 | } |
228 | |
229 | assert(mlir::isa<ArrayAttr>(elts)); |
230 | const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts); |
231 | const auto arrayTy = mlir::cast<ArrayType>(type); |
232 | |
233 | // Make sure both number of elements and subelement types match type. |
234 | if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum) |
235 | return emitError() << "constant array size should match type size" ; |
236 | return success(); |
237 | } |
238 | |
239 | Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) { |
240 | mlir::FailureOr<Type> resultTy; |
241 | mlir::FailureOr<Attribute> resultVal; |
242 | |
243 | // Parse literal '<' |
244 | if (parser.parseLess()) |
245 | return {}; |
246 | |
247 | // Parse variable 'value' |
248 | resultVal = FieldParser<Attribute>::parse(parser); |
249 | if (failed(resultVal)) { |
250 | parser.emitError( |
251 | parser.getCurrentLocation(), |
252 | "failed to parse ConstArrayAttr parameter 'value' which is " |
253 | "to be a `Attribute`" ); |
254 | return {}; |
255 | } |
256 | |
257 | // ArrayAttrrs have per-element type, not the type of the array... |
258 | if (mlir::isa<ArrayAttr>(*resultVal)) { |
259 | // Array has implicit type: infer from const array type. |
260 | if (parser.parseOptionalColon().failed()) { |
261 | resultTy = type; |
262 | } else { // Array has explicit type: parse it. |
263 | resultTy = FieldParser<Type>::parse(parser); |
264 | if (failed(resultTy)) { |
265 | parser.emitError( |
266 | parser.getCurrentLocation(), |
267 | "failed to parse ConstArrayAttr parameter 'type' which is " |
268 | "to be a `::mlir::Type`" ); |
269 | return {}; |
270 | } |
271 | } |
272 | } else { |
273 | auto ta = mlir::cast<TypedAttr>(*resultVal); |
274 | resultTy = ta.getType(); |
275 | if (mlir::isa<mlir::NoneType>(*resultTy)) { |
276 | parser.emitError(parser.getCurrentLocation(), |
277 | "expected type declaration for string literal" ); |
278 | return {}; |
279 | } |
280 | } |
281 | |
282 | unsigned zeros = 0; |
283 | if (parser.parseOptionalComma().succeeded()) { |
284 | if (parser.parseOptionalKeyword("trailing_zeros" ).succeeded()) { |
285 | unsigned typeSize = |
286 | mlir::cast<cir::ArrayType>(resultTy.value()).getSize(); |
287 | mlir::Attribute elts = resultVal.value(); |
288 | if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts)) |
289 | zeros = typeSize - str.size(); |
290 | else |
291 | zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size(); |
292 | } else { |
293 | return {}; |
294 | } |
295 | } |
296 | |
297 | // Parse literal '>' |
298 | if (parser.parseGreater()) |
299 | return {}; |
300 | |
301 | return parser.getChecked<ConstArrayAttr>( |
302 | parser.getCurrentLocation(), parser.getContext(), resultTy.value(), |
303 | resultVal.value(), zeros); |
304 | } |
305 | |
306 | void ConstArrayAttr::print(AsmPrinter &printer) const { |
307 | printer << "<" ; |
308 | printer.printStrippedAttrOrType(getElts()); |
309 | if (getTrailingZerosNum()) |
310 | printer << ", trailing_zeros" ; |
311 | printer << ">" ; |
312 | } |
313 | |
314 | //===----------------------------------------------------------------------===// |
315 | // CIR ConstVectorAttr |
316 | //===----------------------------------------------------------------------===// |
317 | |
318 | LogicalResult |
319 | cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
320 | Type type, ArrayAttr elts) { |
321 | |
322 | if (!mlir::isa<cir::VectorType>(type)) |
323 | return emitError() << "type of cir::ConstVectorAttr is not a " |
324 | "cir::VectorType: " |
325 | << type; |
326 | |
327 | const auto vecType = mlir::cast<cir::VectorType>(type); |
328 | |
329 | if (vecType.getSize() != elts.size()) |
330 | return emitError() |
331 | << "number of constant elements should match vector size" ; |
332 | |
333 | // Check if the types of the elements match |
334 | LogicalResult elementTypeCheck = success(); |
335 | elts.walkImmediateSubElements( |
336 | [&](Attribute element) { |
337 | if (elementTypeCheck.failed()) { |
338 | // An earlier element didn't match |
339 | return; |
340 | } |
341 | auto typedElement = mlir::dyn_cast<TypedAttr>(element); |
342 | if (!typedElement || |
343 | typedElement.getType() != vecType.getElementType()) { |
344 | elementTypeCheck = failure(); |
345 | emitError() << "constant type should match vector element type" ; |
346 | } |
347 | }, |
348 | [&](Type) {}); |
349 | |
350 | return elementTypeCheck; |
351 | } |
352 | |
353 | //===----------------------------------------------------------------------===// |
354 | // CIR Dialect |
355 | //===----------------------------------------------------------------------===// |
356 | |
// Registers with the CIR dialect every attribute class listed in the
// GET_ATTRDEF_LIST section of the generated CIROpsAttributes.cpp.inc.
void CIRDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"
      >();
}
363 | |