1//===-- FIRType.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Dialect/FIRType.h"
14#include "flang/Common/ISO_Fortran_binding_wrapper.h"
15#include "flang/Optimizer/Builder/Todo.h"
16#include "flang/Optimizer/Dialect/FIRDialect.h"
17#include "flang/Optimizer/Dialect/Support/KindMapping.h"
18#include "flang/Tools/PointerModels.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinDialect.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/ErrorHandling.h"
28
29#define GET_TYPEDEF_CLASSES
30#include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
31
32using namespace fir;
33
34namespace {
35
36template <typename TYPE>
37TYPE parseIntSingleton(mlir::AsmParser &parser) {
38 int kind = 0;
39 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater())
40 return {};
41 return TYPE::get(parser.getContext(), kind);
42}
43
44template <typename TYPE>
45TYPE parseKindSingleton(mlir::AsmParser &parser) {
46 return parseIntSingleton<TYPE>(parser);
47}
48
49template <typename TYPE>
50TYPE parseRankSingleton(mlir::AsmParser &parser) {
51 return parseIntSingleton<TYPE>(parser);
52}
53
54template <typename TYPE>
55TYPE parseTypeSingleton(mlir::AsmParser &parser) {
56 mlir::Type ty;
57 if (parser.parseLess() || parser.parseType(result&: ty) || parser.parseGreater())
58 return {};
59 return TYPE::get(ty);
60}
61
62/// Is `ty` a standard or FIR integer type?
63static bool isaIntegerType(mlir::Type ty) {
64 // TODO: why aren't we using isa_integer? investigatation required.
65 return mlir::isa<mlir::IntegerType, fir::IntegerType>(ty);
66}
67
68bool verifyRecordMemberType(mlir::Type ty) {
69 return !mlir::isa<BoxCharType, ShapeType, ShapeShiftType, ShiftType,
70 SliceType, FieldType, LenType, ReferenceType, TypeDescType>(
71 ty);
72}
73
74bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,
75 llvm::ArrayRef<RecordType::TypePair> a2) {
76 // FIXME: do we need to allow for any variance here?
77 return a1 == a2;
78}
79
80static llvm::StringRef getVolatileKeyword() { return "volatile"; }
81
82static mlir::ParseResult parseOptionalCommaAndKeyword(mlir::AsmParser &parser,
83 mlir::StringRef keyword,
84 bool &parsedKeyword) {
85 if (!parser.parseOptionalComma()) {
86 if (parser.parseKeyword(keyword))
87 return mlir::failure();
88 parsedKeyword = true;
89 return mlir::success();
90 }
91 parsedKeyword = false;
92 return mlir::success();
93}
94
95RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
96 llvm::ArrayRef<RecordType::TypePair> lenPList,
97 llvm::ArrayRef<RecordType::TypePair> typeList) {
98 auto loc = parser.getNameLoc();
99 if (!verifySameLists(derivedTy.getLenParamList(), lenPList) ||
100 !verifySameLists(derivedTy.getTypeList(), typeList)) {
101 parser.emitError(loc, message: "cannot redefine record type members");
102 return {};
103 }
104 for (auto &p : lenPList)
105 if (!isaIntegerType(p.second)) {
106 parser.emitError(loc, "LEN parameter must be integral type");
107 return {};
108 }
109 for (auto &p : typeList)
110 if (!verifyRecordMemberType(p.second)) {
111 parser.emitError(loc, "field parameter has invalid type");
112 return {};
113 }
114 llvm::StringSet<> uniq;
115 for (auto &p : lenPList)
116 if (!uniq.insert(p.first).second) {
117 parser.emitError(loc, "LEN parameter cannot have duplicate name");
118 return {};
119 }
120 for (auto &p : typeList)
121 if (!uniq.insert(p.first).second) {
122 parser.emitError(loc, "field cannot have duplicate name");
123 return {};
124 }
125 return derivedTy;
126}
127
128} // namespace
129
130// Implementation of the thin interface from dialect to type parser
131
132mlir::Type fir::parseFirType(FIROpsDialect *dialect,
133 mlir::DialectAsmParser &parser) {
134 mlir::StringRef typeTag;
135 mlir::Type genType;
136 auto parseResult = generatedTypeParser(parser, &typeTag, genType);
137 if (parseResult.has_value())
138 return genType;
139 parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
140 return {};
141}
142
143namespace fir {
144namespace detail {
145
146// Type storage classes
147
148/// Derived type storage
149struct RecordTypeStorage : public mlir::TypeStorage {
150 using KeyTy = llvm::StringRef;
151
152 static unsigned hashKey(const KeyTy &key) {
153 return llvm::hash_combine(args: key.str());
154 }
155
156 bool operator==(const KeyTy &key) const { return key == getName(); }
157
158 static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
159 const KeyTy &key) {
160 auto *storage = allocator.allocate<RecordTypeStorage>();
161 return new (storage) RecordTypeStorage{key};
162 }
163
164 llvm::StringRef getName() const { return name; }
165
166 void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) {
167 lens = list;
168 }
169 llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; }
170
171 void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; }
172 llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; }
173
174 bool isFinalized() const { return finalized; }
175 void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList,
176 llvm::ArrayRef<RecordType::TypePair> typeList) {
177 if (finalized)
178 return;
179 finalized = true;
180 setLenParamList(lenParamList);
181 setTypeList(typeList);
182 }
183
184 bool isPacked() const { return packed; }
185 void pack(bool p) { packed = p; }
186
187protected:
188 std::string name;
189 bool finalized;
190 bool packed;
191 std::vector<RecordType::TypePair> lens;
192 std::vector<RecordType::TypePair> types;
193
194private:
195 RecordTypeStorage() = delete;
196 explicit RecordTypeStorage(llvm::StringRef name)
197 : name{name}, finalized{false}, packed{false} {}
198};
199
200} // namespace detail
201
/// Return true when `v` lies in the half-open interval [lb, ub).
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
  if (!(v >= lb))
    return false;
  return v < ub;
}
206
207bool isa_fir_type(mlir::Type t) {
208 return llvm::isa<FIROpsDialect>(t.getDialect());
209}
210
211bool isa_std_type(mlir::Type t) {
212 return llvm::isa<mlir::BuiltinDialect>(Val: t.getDialect());
213}
214
215bool isa_fir_or_std_type(mlir::Type t) {
216 if (auto funcType = mlir::dyn_cast<mlir::FunctionType>(t))
217 return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) &&
218 llvm::all_of(funcType.getResults(), isa_fir_or_std_type);
219 return isa_fir_type(t) || isa_std_type(t);
220}
221
222mlir::Type getDerivedType(mlir::Type ty) {
223 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
224 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto p) {
225 if (auto seq = mlir::dyn_cast<fir::SequenceType>(p.getEleTy()))
226 return seq.getEleTy();
227 return p.getEleTy();
228 })
229 .Case<fir::BaseBoxType>(
230 [](auto p) { return getDerivedType(p.getEleTy()); })
231 .Default([](mlir::Type t) { return t; });
232}
233
234mlir::Type updateTypeWithVolatility(mlir::Type type, bool isVolatile) {
235 // If we already have the volatility we asked for, return the type unchanged.
236 if (fir::isa_volatile_type(type) == isVolatile)
237 return type;
238 return mlir::TypeSwitch<mlir::Type, mlir::Type>(type)
239 .Case<fir::BoxType, fir::ClassType, fir::ReferenceType>(
240 [&](auto ty) -> mlir::Type {
241 using TYPE = decltype(ty);
242 return TYPE::get(ty.getEleTy(), isVolatile);
243 })
244 .Default([&](mlir::Type t) -> mlir::Type { return t; });
245}
246
247mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
248 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
249 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
250 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
251 .Default([](mlir::Type) { return mlir::Type{}; });
252}
253
254mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
255 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
256 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
257 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
258 .Case<fir::BaseBoxType, fir::BoxCharType>(
259 [](auto p) { return unwrapRefType(p.getEleTy()); })
260 .Default([](mlir::Type) { return mlir::Type{}; });
261}
262
263static bool hasDynamicSize(fir::RecordType recTy) {
264 for (auto field : recTy.getTypeList()) {
265 if (auto arr = mlir::dyn_cast<fir::SequenceType>(field.second)) {
266 if (sequenceWithNonConstantShape(arr))
267 return true;
268 } else if (characterWithDynamicLen(field.second)) {
269 return true;
270 } else if (auto rec = mlir::dyn_cast<fir::RecordType>(field.second)) {
271 if (hasDynamicSize(rec))
272 return true;
273 }
274 }
275 return false;
276}
277
278bool hasDynamicSize(mlir::Type t) {
279 if (auto arr = mlir::dyn_cast<fir::SequenceType>(t)) {
280 if (sequenceWithNonConstantShape(arr))
281 return true;
282 t = arr.getEleTy();
283 }
284 if (characterWithDynamicLen(t))
285 return true;
286 if (auto rec = mlir::dyn_cast<fir::RecordType>(t))
287 return hasDynamicSize(rec);
288 return false;
289}
290
291mlir::Type extractSequenceType(mlir::Type ty) {
292 if (mlir::isa<fir::SequenceType>(ty))
293 return ty;
294 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
295 return extractSequenceType(boxTy.getEleTy());
296 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
297 return extractSequenceType(heapTy.getEleTy());
298 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
299 return extractSequenceType(ptrTy.getEleTy());
300 return mlir::Type{};
301}
302
303bool isPointerType(mlir::Type ty) {
304 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
305 ty = refTy;
306 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
307 return mlir::isa<fir::PointerType>(boxTy.getEleTy());
308 return false;
309}
310
311bool isAllocatableType(mlir::Type ty) {
312 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
313 ty = refTy;
314 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
315 return mlir::isa<fir::HeapType>(boxTy.getEleTy());
316 return false;
317}
318
319bool isBoxNone(mlir::Type ty) {
320 if (auto box = mlir::dyn_cast<fir::BoxType>(ty))
321 return mlir::isa<mlir::NoneType>(box.getEleTy());
322 return false;
323}
324
325bool isBoxedRecordType(mlir::Type ty) {
326 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
327 ty = refTy;
328 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
329 if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
330 return true;
331 mlir::Type innerType = boxTy.unwrapInnerType();
332 return innerType && mlir::isa<fir::RecordType>(innerType);
333 }
334 return false;
335}
336
337bool isScalarBoxedRecordType(mlir::Type ty) {
338 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
339 ty = refTy;
340 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
341 if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
342 return true;
343 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
344 return mlir::isa<fir::RecordType>(heapTy.getEleTy());
345 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
346 return mlir::isa<fir::RecordType>(ptrTy.getEleTy());
347 }
348 return false;
349}
350
351bool isAssumedType(mlir::Type ty) {
352 // Rule out CLASS(*) which are `fir.class<[fir.array] none>`.
353 if (mlir::isa<fir::ClassType>(ty))
354 return false;
355 mlir::Type valueType = fir::unwrapPassByRefType(fir::unwrapRefType(ty));
356 // Refuse raw `none` or `fir.array<none>` since assumed type
357 // should be in memory variables.
358 if (valueType == ty)
359 return false;
360 mlir::Type inner = fir::unwrapSequenceType(valueType);
361 return mlir::isa<mlir::NoneType>(Val: inner);
362}
363
364bool isAssumedShape(mlir::Type ty) {
365 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
366 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
367 return seqTy.hasDynamicExtents();
368 return false;
369}
370
371bool isAllocatableOrPointerArray(mlir::Type ty) {
372 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
373 ty = refTy;
374 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
375 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
376 return mlir::isa<fir::SequenceType>(heapTy.getEleTy());
377 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
378 return mlir::isa<fir::SequenceType>(ptrTy.getEleTy());
379 }
380 return false;
381}
382
383bool isTypeWithDescriptor(mlir::Type ty) {
384 if (mlir::isa<fir::BaseBoxType>(unwrapRefType(ty)))
385 return true;
386 return false;
387}
388
389bool isPolymorphicType(mlir::Type ty) {
390 // CLASS(T) or CLASS(*)
391 if (mlir::isa<fir::ClassType>(fir::unwrapRefType(ty)))
392 return true;
393 // assumed type are polymorphic.
394 return isAssumedType(ty);
395}
396
397bool isUnlimitedPolymorphicType(mlir::Type ty) {
398 // CLASS(*)
399 if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) {
400 if (mlir::isa<mlir::NoneType>(clTy.getEleTy()))
401 return true;
402 mlir::Type innerType = clTy.unwrapInnerType();
403 return innerType && mlir::isa<mlir::NoneType>(Val: innerType);
404 }
405 // TYPE(*)
406 return isAssumedType(ty);
407}
408
409mlir::Type unwrapInnerType(mlir::Type ty) {
410 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
411 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto t) {
412 mlir::Type eleTy = t.getEleTy();
413 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
414 return seqTy.getEleTy();
415 return eleTy;
416 })
417 .Case<fir::RecordType>([](auto t) { return t; })
418 .Default([](mlir::Type) { return mlir::Type{}; });
419}
420
421bool isRecordWithAllocatableMember(mlir::Type ty) {
422 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty))
423 for (auto [field, memTy] : recTy.getTypeList()) {
424 if (fir::isAllocatableType(memTy))
425 return true;
426 // A record type cannot recursively include itself as a direct member.
427 // There must be an intervening `ptr` type, so recursion is safe here.
428 if (mlir::isa<fir::RecordType>(memTy) &&
429 isRecordWithAllocatableMember(memTy))
430 return true;
431 }
432 return false;
433}
434
435bool isRecordWithDescriptorMember(mlir::Type ty) {
436 ty = unwrapSequenceType(ty);
437 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty))
438 for (auto [field, memTy] : recTy.getTypeList()) {
439 memTy = unwrapSequenceType(memTy);
440 if (mlir::isa<fir::BaseBoxType>(memTy))
441 return true;
442 if (mlir::isa<fir::RecordType>(memTy) &&
443 isRecordWithDescriptorMember(memTy))
444 return true;
445 }
446 return false;
447}
448
449mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) {
450 while (true) {
451 mlir::Type nt = unwrapSequenceType(unwrapRefType(ty));
452 if (auto vecTy = mlir::dyn_cast<fir::VectorType>(nt))
453 nt = vecTy.getEleTy();
454 if (nt == ty)
455 return ty;
456 ty = nt;
457 }
458}
459
460mlir::Type getFortranElementType(mlir::Type ty) {
461 return fir::unwrapSequenceType(
462 fir::unwrapPassByRefType(fir::unwrapRefType(ty)));
463}
464
465mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) {
466 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
467 return seqTy.getEleTy();
468 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
469 auto eleTy = unwrapRefType(boxTy.getEleTy());
470 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
471 return seqTy.getEleTy();
472 }
473 return ty;
474}
475
476unsigned getBoxRank(mlir::Type boxTy) {
477 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(t: boxTy);
478 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
479 return seqTy.getDimension();
480 return 0;
481}
482
483/// Return the ISO_C_BINDING intrinsic module value of type \p ty.
484int getTypeCode(mlir::Type ty, const fir::KindMapping &kindMap) {
485 if (mlir::IntegerType intTy = mlir::dyn_cast<mlir::IntegerType>(ty)) {
486 if (intTy.isUnsigned()) {
487 switch (intTy.getWidth()) {
488 case 8:
489 return CFI_type_uint8_t;
490 case 16:
491 return CFI_type_uint16_t;
492 case 32:
493 return CFI_type_uint32_t;
494 case 64:
495 return CFI_type_uint64_t;
496 case 128:
497 return CFI_type_uint128_t;
498 }
499 llvm_unreachable("unsupported integer type");
500 } else {
501 switch (intTy.getWidth()) {
502 case 8:
503 return CFI_type_int8_t;
504 case 16:
505 return CFI_type_int16_t;
506 case 32:
507 return CFI_type_int32_t;
508 case 64:
509 return CFI_type_int64_t;
510 case 128:
511 return CFI_type_int128_t;
512 }
513 llvm_unreachable("unsupported integer type");
514 }
515 }
516 if (fir::LogicalType logicalTy = mlir::dyn_cast<fir::LogicalType>(ty)) {
517 switch (kindMap.getLogicalBitsize(logicalTy.getFKind())) {
518 case 8:
519 return CFI_type_Bool;
520 case 16:
521 return CFI_type_int_least16_t;
522 case 32:
523 return CFI_type_int_least32_t;
524 case 64:
525 return CFI_type_int_least64_t;
526 }
527 llvm_unreachable("unsupported logical type");
528 }
529 if (mlir::FloatType floatTy = mlir::dyn_cast<mlir::FloatType>(ty)) {
530 switch (floatTy.getWidth()) {
531 case 16:
532 return floatTy.isBF16() ? CFI_type_bfloat : CFI_type_half_float;
533 case 32:
534 return CFI_type_float;
535 case 64:
536 return CFI_type_double;
537 case 80:
538 return CFI_type_extended_double;
539 case 128:
540 return CFI_type_float128;
541 }
542 llvm_unreachable("unsupported real type");
543 }
544 if (mlir::ComplexType complexTy = mlir::dyn_cast<mlir::ComplexType>(ty)) {
545 mlir::FloatType floatTy =
546 mlir::cast<mlir::FloatType>(complexTy.getElementType());
547 if (floatTy.isBF16())
548 return CFI_type_bfloat_Complex;
549 switch (floatTy.getWidth()) {
550 case 16:
551 return CFI_type_half_float_Complex;
552 case 32:
553 return CFI_type_float_Complex;
554 case 64:
555 return CFI_type_double_Complex;
556 case 80:
557 return CFI_type_extended_double_Complex;
558 case 128:
559 return CFI_type_float128_Complex;
560 }
561 llvm_unreachable("unsupported complex size");
562 }
563 if (fir::CharacterType charTy = mlir::dyn_cast<fir::CharacterType>(ty)) {
564 switch (kindMap.getCharacterBitsize(charTy.getFKind())) {
565 case 8:
566 return CFI_type_char;
567 case 16:
568 return CFI_type_char16_t;
569 case 32:
570 return CFI_type_char32_t;
571 }
572 llvm_unreachable("unsupported character type");
573 }
574 if (fir::isa_ref_type(ty))
575 return CFI_type_cptr;
576 if (mlir::isa<fir::RecordType>(ty))
577 return CFI_type_struct;
578 llvm_unreachable("unsupported type");
579}
580
581std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap,
582 llvm::StringRef prefix) {
583 std::string buf = prefix.str();
584 llvm::raw_string_ostream name{buf};
585 if (!prefix.empty())
586 name << "_";
587
588 std::function<void(mlir::Type)> appendTypeName = [&](mlir::Type ty) {
589 while (ty) {
590 if (fir::isa_trivial(ty)) {
591 if (mlir::isa<mlir::IndexType>(Val: ty)) {
592 name << "idx";
593 } else if (ty.isIntOrIndex()) {
594 name << 'i' << ty.getIntOrFloatBitWidth();
595 } else if (mlir::isa<mlir::FloatType>(Val: ty)) {
596 name << 'f' << ty.getIntOrFloatBitWidth();
597 } else if (auto cplxTy =
598 mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) {
599 name << 'z';
600 auto floatTy = mlir::cast<mlir::FloatType>(cplxTy.getElementType());
601 name << floatTy.getWidth();
602 } else if (auto logTy = mlir::dyn_cast_or_null<fir::LogicalType>(ty)) {
603 name << 'l' << kindMap.getLogicalBitsize(logTy.getFKind());
604 } else {
605 llvm::report_fatal_error(reason: "unsupported type");
606 }
607 break;
608 } else if (mlir::isa<mlir::NoneType>(Val: ty)) {
609 name << "none";
610 break;
611 } else if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(ty)) {
612 name << 'c' << kindMap.getCharacterBitsize(charTy.getFKind());
613 if (charTy.getLen() == fir::CharacterType::unknownLen())
614 name << "xU";
615 else if (charTy.getLen() != fir::CharacterType::singleton())
616 name << "x" << charTy.getLen();
617 break;
618 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
619 for (auto extent : seqTy.getShape()) {
620 if (extent == fir::SequenceType::getUnknownExtent())
621 name << "Ux";
622 else
623 name << extent << 'x';
624 }
625 ty = seqTy.getEleTy();
626 } else if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
627 name << "ref_";
628 ty = refTy.getEleTy();
629 } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::PointerType>(ty)) {
630 name << "ptr_";
631 ty = ptrTy.getEleTy();
632 } else if (auto ptrTy =
633 mlir::dyn_cast_or_null<fir::LLVMPointerType>(ty)) {
634 name << "llvmptr_";
635 ty = ptrTy.getEleTy();
636 } else if (auto heapTy = mlir::dyn_cast_or_null<fir::HeapType>(ty)) {
637 name << "heap_";
638 ty = heapTy.getEleTy();
639 } else if (auto classTy = mlir::dyn_cast_or_null<fir::ClassType>(ty)) {
640 name << "class_";
641 ty = classTy.getEleTy();
642 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BoxType>(ty)) {
643 name << "box_";
644 ty = boxTy.getEleTy();
645 } else if (auto boxcharTy =
646 mlir::dyn_cast_or_null<fir::BoxCharType>(ty)) {
647 name << "boxchar_";
648 ty = boxcharTy.getEleTy();
649 } else if (auto boxprocTy =
650 mlir::dyn_cast_or_null<fir::BoxProcType>(ty)) {
651 name << "boxproc_";
652 auto procTy = mlir::dyn_cast<mlir::FunctionType>(boxprocTy.getEleTy());
653 assert(procTy.getNumResults() <= 1 &&
654 "function type with more than one result");
655 for (const auto &result : procTy.getResults())
656 appendTypeName(result);
657 name << "_args";
658 for (const auto &arg : procTy.getInputs()) {
659 name << '_';
660 appendTypeName(arg);
661 }
662 break;
663 } else if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(ty)) {
664 name << "rec_" << recTy.getName();
665 break;
666 } else {
667 llvm::report_fatal_error(reason: "unsupported type");
668 }
669 }
670 };
671
672 appendTypeName(ty);
673 return buf;
674}
675
676mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
677 bool turnBoxIntoClass) {
678 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
679 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
680 return fir::SequenceType::get(seqTy.getShape(), newElementType);
681 })
682 .Case<fir::ReferenceType, fir::ClassType>([&](auto t) -> mlir::Type {
683 using FIRT = decltype(t);
684 auto newEleTy =
685 changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass);
686 return FIRT::get(newEleTy, t.isVolatile());
687 })
688 .Case<fir::PointerType, fir::HeapType>([&](auto t) -> mlir::Type {
689 using FIRT = decltype(t);
690 return FIRT::get(
691 changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
692 })
693 .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
694 mlir::Type newInnerType =
695 changeElementType(t.getEleTy(), newElementType, false);
696 if (turnBoxIntoClass)
697 return fir::ClassType::get(newInnerType, t.isVolatile());
698 return fir::BoxType::get(newInnerType, t.isVolatile());
699 })
700 .Default([&](mlir::Type t) -> mlir::Type {
701 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
702 llvm::isa<mlir::NoneType>(t)) &&
703 "unexpected FIR leaf type");
704 return newElementType;
705 });
706}
707
708} // namespace fir
709
710namespace {
711
712static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4>
713 recordTypeVisited;
714
715} // namespace
716
717void fir::verifyIntegralType(mlir::Type type) {
718 if (isaIntegerType(ty: type) || mlir::isa<mlir::IndexType>(Val: type))
719 return;
720 llvm::report_fatal_error(reason: "expected integral type");
721}
722
723void fir::printFirType(FIROpsDialect *, mlir::Type ty,
724 mlir::DialectAsmPrinter &p) {
725 if (mlir::failed(Result: generatedTypePrinter(ty, p)))
726 llvm::report_fatal_error(reason: "unknown type to print");
727}
728
729bool fir::isa_unknown_size_box(mlir::Type t) {
730 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(t)) {
731 auto valueType = fir::unwrapPassByRefType(boxTy);
732 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(valueType))
733 if (seqTy.hasUnknownShape())
734 return true;
735 }
736 return false;
737}
738
739bool fir::isa_volatile_type(mlir::Type t) {
740 return llvm::TypeSwitch<mlir::Type, bool>(t)
741 .Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(
742 [](auto t) { return t.isVolatile(); })
743 .Default([](mlir::Type) { return false; });
744}
745
746//===----------------------------------------------------------------------===//
747// BoxProcType
748//===----------------------------------------------------------------------===//
749
750// `boxproc` `<` return-type `>`
751mlir::Type BoxProcType::parse(mlir::AsmParser &parser) {
752 mlir::Type ty;
753 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
754 return {};
755 return get(parser.getContext(), ty);
756}
757
758void fir::BoxProcType::print(mlir::AsmPrinter &printer) const {
759 printer << "<" << getEleTy() << '>';
760}
761
762llvm::LogicalResult
763BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
764 mlir::Type eleTy) {
765 if (mlir::isa<mlir::FunctionType>(eleTy))
766 return mlir::success();
767 if (auto refTy = mlir::dyn_cast<ReferenceType>(eleTy))
768 if (mlir::isa<mlir::FunctionType>(refTy))
769 return mlir::success();
770 return emitError() << "invalid type for boxproc" << eleTy << '\n';
771}
772
773static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
774 return mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
775 SliceType, FieldType, LenType, HeapType, PointerType,
776 ReferenceType, TypeDescType>(eleTy);
777}
778
779//===----------------------------------------------------------------------===//
780// BoxType
781//===----------------------------------------------------------------------===//
782
783// `box` `<` type (`, volatile` $volatile^)? `>`
784mlir::Type fir::BoxType::parse(mlir::AsmParser &parser) {
785 mlir::Type eleTy;
786 auto location = parser.getCurrentLocation();
787 auto *context = parser.getContext();
788 bool isVolatile = false;
789 if (parser.parseLess() || parser.parseType(eleTy))
790 return {};
791 if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
792 return {};
793 if (parser.parseGreater())
794 return {};
795 return parser.getChecked<fir::BoxType>(location, context, eleTy, isVolatile);
796}
797
798void fir::BoxType::print(mlir::AsmPrinter &printer) const {
799 printer << "<" << getEleTy();
800 if (isVolatile())
801 printer << ", " << getVolatileKeyword();
802 printer << '>';
803}
804
805llvm::LogicalResult
806fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
807 mlir::Type eleTy, bool isVolatile) {
808 if (mlir::isa<fir::BaseBoxType>(eleTy))
809 return emitError() << "invalid element type\n";
810 // TODO
811 return mlir::success();
812}
813
814//===----------------------------------------------------------------------===//
815// BoxCharType
816//===----------------------------------------------------------------------===//
817
818mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) {
819 return parseKindSingleton<fir::BoxCharType>(parser);
820}
821
822void fir::BoxCharType::print(mlir::AsmPrinter &printer) const {
823 printer << "<" << getKind() << ">";
824}
825
826CharacterType
827fir::BoxCharType::getElementType(mlir::MLIRContext *context) const {
828 return CharacterType::getUnknownLen(context, getKind());
829}
830
831CharacterType fir::BoxCharType::getEleTy() const {
832 return getElementType(getContext());
833}
834
835//===----------------------------------------------------------------------===//
836// CharacterType
837//===----------------------------------------------------------------------===//
838
839// `char` `<` kind [`,` `len`] `>`
840mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) {
841 int kind = 0;
842 if (parser.parseLess() || parser.parseInteger(kind))
843 return {};
844 CharacterType::LenType len = 1;
845 if (mlir::succeeded(parser.parseOptionalComma())) {
846 if (mlir::succeeded(parser.parseOptionalQuestion())) {
847 len = fir::CharacterType::unknownLen();
848 } else if (!mlir::succeeded(parser.parseInteger(len))) {
849 return {};
850 }
851 }
852 if (parser.parseGreater())
853 return {};
854 return get(parser.getContext(), kind, len);
855}
856
857void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
858 printer << "<" << getFKind();
859 auto len = getLen();
860 if (len != fir::CharacterType::singleton()) {
861 printer << ',';
862 if (len == fir::CharacterType::unknownLen())
863 printer << '?';
864 else
865 printer << len;
866 }
867 printer << '>';
868}
869
870//===----------------------------------------------------------------------===//
871// ClassType
872//===----------------------------------------------------------------------===//
873
874// `class` `<` type (`, volatile` $volatile^)? `>`
875mlir::Type fir::ClassType::parse(mlir::AsmParser &parser) {
876 mlir::Type eleTy;
877 auto location = parser.getCurrentLocation();
878 auto *context = parser.getContext();
879 bool isVolatile = false;
880 if (parser.parseLess() || parser.parseType(eleTy))
881 return {};
882 if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
883 return {};
884 if (parser.parseGreater())
885 return {};
886 return parser.getChecked<fir::ClassType>(location, context, eleTy,
887 isVolatile);
888}
889
890void fir::ClassType::print(mlir::AsmPrinter &printer) const {
891 printer << "<" << getEleTy();
892 if (isVolatile())
893 printer << ", " << getVolatileKeyword();
894 printer << '>';
895}
896
897llvm::LogicalResult
898fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
899 mlir::Type eleTy, bool isVolatile) {
900 if (mlir::isa<fir::RecordType, fir::SequenceType, fir::HeapType,
901 fir::PointerType, mlir::NoneType, mlir::IntegerType,
902 mlir::FloatType, fir::CharacterType, fir::LogicalType,
903 mlir::ComplexType>(eleTy))
904 return mlir::success();
905 return emitError() << "invalid element type\n";
906}
907
908//===----------------------------------------------------------------------===//
909// HeapType
910//===----------------------------------------------------------------------===//
911
912// `heap` `<` type `>`
913mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) {
914 return parseTypeSingleton<HeapType>(parser);
915}
916
917void fir::HeapType::print(mlir::AsmPrinter &printer) const {
918 printer << "<" << getEleTy() << '>';
919}
920
921llvm::LogicalResult
922fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
923 mlir::Type eleTy) {
924 if (cannotBePointerOrHeapElementType(eleTy))
925 return emitError() << "cannot build a heap pointer to type: " << eleTy
926 << '\n';
927 return mlir::success();
928}
929
930//===----------------------------------------------------------------------===//
931// IntegerType
932//===----------------------------------------------------------------------===//
933
934// `int` `<` kind `>`
935mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) {
936 return parseKindSingleton<fir::IntegerType>(parser);
937}
938
939void fir::IntegerType::print(mlir::AsmPrinter &printer) const {
940 printer << "<" << getFKind() << '>';
941}
942
943//===----------------------------------------------------------------------===//
944// UnsignedType
945//===----------------------------------------------------------------------===//
946
947// `unsigned` `<` kind `>`
948mlir::Type fir::UnsignedType::parse(mlir::AsmParser &parser) {
949 return parseKindSingleton<fir::UnsignedType>(parser);
950}
951
952void fir::UnsignedType::print(mlir::AsmPrinter &printer) const {
953 printer << "<" << getFKind() << '>';
954}
955
956//===----------------------------------------------------------------------===//
957// LogicalType
958//===----------------------------------------------------------------------===//
959
960// `logical` `<` kind `>`
961mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) {
962 return parseKindSingleton<fir::LogicalType>(parser);
963}
964
965void fir::LogicalType::print(mlir::AsmPrinter &printer) const {
966 printer << "<" << getFKind() << '>';
967}
968
969//===----------------------------------------------------------------------===//
970// PointerType
971//===----------------------------------------------------------------------===//
972
973// `ptr` `<` type `>`
974mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) {
975 return parseTypeSingleton<fir::PointerType>(parser);
976}
977
978void fir::PointerType::print(mlir::AsmPrinter &printer) const {
979 printer << "<" << getEleTy() << '>';
980}
981
982llvm::LogicalResult fir::PointerType::verify(
983 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
984 mlir::Type eleTy) {
985 if (cannotBePointerOrHeapElementType(eleTy))
986 return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
987 return mlir::success();
988}
989
990//===----------------------------------------------------------------------===//
991// RecordType
992//===----------------------------------------------------------------------===//
993
994// Fortran derived type
995// unpacked:
996// `type` `<` name
997// (`(` id `:` type (`,` id `:` type)* `)`)?
998// (`{` id `:` type (`,` id `:` type)* `}`)? '>'
999// packed:
1000// `type` `<` name
1001// (`(` id `:` type (`,` id `:` type)* `)`)?
1002// (`<{` id `:` type (`,` id `:` type)* `}>`)? '>'
1003mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
1004 llvm::StringRef name;
1005 if (parser.parseLess() || parser.parseKeyword(&name))
1006 return {};
1007 RecordType result = RecordType::get(parser.getContext(), name);
1008
1009 RecordType::TypeList lenParamList;
1010 if (!parser.parseOptionalLParen()) {
1011 while (true) {
1012 llvm::StringRef lenparam;
1013 mlir::Type intTy;
1014 if (parser.parseKeyword(&lenparam) || parser.parseColon() ||
1015 parser.parseType(intTy)) {
1016 parser.emitError(parser.getNameLoc(), "expected LEN parameter list");
1017 return {};
1018 }
1019 lenParamList.emplace_back(lenparam, intTy);
1020 if (parser.parseOptionalComma())
1021 break;
1022 }
1023 if (parser.parseRParen())
1024 return {};
1025 }
1026
1027 RecordType::TypeList typeList;
1028 if (!parser.parseOptionalLess()) {
1029 result.pack(true);
1030 }
1031
1032 if (!parser.parseOptionalLBrace()) {
1033 while (true) {
1034 llvm::StringRef field;
1035 mlir::Type fldTy;
1036 if (parser.parseKeyword(&field) || parser.parseColon() ||
1037 parser.parseType(fldTy)) {
1038 parser.emitError(parser.getNameLoc(), "expected field type list");
1039 return {};
1040 }
1041 typeList.emplace_back(field, fldTy);
1042 if (parser.parseOptionalComma())
1043 break;
1044 }
1045 if (parser.parseOptionalGreater()) {
1046 if (parser.parseRBrace())
1047 return {};
1048 }
1049 }
1050
1051 if (parser.parseGreater())
1052 return {};
1053
1054 if (lenParamList.empty() && typeList.empty())
1055 return result;
1056
1057 result.finalize(lenParamList, typeList);
1058 return verifyDerived(parser, result, lenParamList, typeList);
1059}
1060
1061void fir::RecordType::print(mlir::AsmPrinter &printer) const {
1062 printer << "<" << getName();
1063 if (!recordTypeVisited.count(uniqueKey())) {
1064 recordTypeVisited.insert(uniqueKey());
1065 if (getLenParamList().size()) {
1066 char ch = '(';
1067 for (auto p : getLenParamList()) {
1068 printer << ch << p.first << ':';
1069 p.second.print(printer.getStream());
1070 ch = ',';
1071 }
1072 printer << ')';
1073 }
1074 if (getTypeList().size()) {
1075 if (isPacked()) {
1076 printer << '<';
1077 }
1078 char ch = '{';
1079 for (auto p : getTypeList()) {
1080 printer << ch << p.first << ':';
1081 p.second.print(printer.getStream());
1082 ch = ',';
1083 }
1084 printer << '}';
1085 if (isPacked()) {
1086 printer << '>';
1087 }
1088 }
1089 recordTypeVisited.erase(uniqueKey());
1090 }
1091 printer << '>';
1092}
1093
1094void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
1095 llvm::ArrayRef<TypePair> typeList) {
1096 getImpl()->finalize(lenPList, typeList);
1097}
1098
1099llvm::StringRef fir::RecordType::getName() const {
1100 return getImpl()->getName();
1101}
1102
1103RecordType::TypeList fir::RecordType::getTypeList() const {
1104 return getImpl()->getTypeList();
1105}
1106
1107RecordType::TypeList fir::RecordType::getLenParamList() const {
1108 return getImpl()->getLenParamList();
1109}
1110
1111bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
1112
1113void fir::RecordType::pack(bool p) { getImpl()->pack(p); }
1114
1115bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }
1116
1117detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
1118 return getImpl();
1119}
1120
1121llvm::LogicalResult fir::RecordType::verify(
1122 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1123 llvm::StringRef name) {
1124 if (name.size() == 0)
1125 return emitError() << "record types must have a name";
1126 return mlir::success();
1127}
1128
1129mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
1130 for (auto f : getTypeList())
1131 if (ident == f.first)
1132 return f.second;
1133 return {};
1134}
1135
1136unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
1137 for (auto f : llvm::enumerate(getTypeList()))
1138 if (ident == f.value().first)
1139 return f.index();
1140 return std::numeric_limits<unsigned>::max();
1141}
1142
1143//===----------------------------------------------------------------------===//
1144// ReferenceType
1145//===----------------------------------------------------------------------===//
1146
1147// `ref` `<` type (`, volatile` $volatile^)? `>`
1148mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
1149 auto location = parser.getCurrentLocation();
1150 auto *context = parser.getContext();
1151 mlir::Type eleTy;
1152 bool isVolatile = false;
1153 if (parser.parseLess() || parser.parseType(eleTy))
1154 return {};
1155 if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
1156 return {};
1157 if (parser.parseGreater())
1158 return {};
1159 return parser.getChecked<fir::ReferenceType>(location, context, eleTy,
1160 isVolatile);
1161}
1162
1163void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
1164 printer << "<" << getEleTy();
1165 if (isVolatile())
1166 printer << ", " << getVolatileKeyword();
1167 printer << '>';
1168}
1169
1170llvm::LogicalResult fir::ReferenceType::verify(
1171 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
1172 bool isVolatile) {
1173 if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
1174 ReferenceType, TypeDescType>(eleTy))
1175 return emitError() << "cannot build a reference to type: " << eleTy << '\n';
1176 return mlir::success();
1177}
1178
1179//===----------------------------------------------------------------------===//
1180// SequenceType
1181//===----------------------------------------------------------------------===//
1182
1183// `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>`
1184// bounds ::= `?` | int-lit
1185mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) {
1186 if (parser.parseLess())
1187 return {};
1188 SequenceType::Shape shape;
1189 if (parser.parseOptionalStar()) {
1190 if (parser.parseDimensionList(shape, /*allowDynamic=*/true))
1191 return {};
1192 } else if (parser.parseColon()) {
1193 return {};
1194 }
1195 mlir::Type eleTy;
1196 if (parser.parseType(eleTy))
1197 return {};
1198 mlir::AffineMapAttr map;
1199 if (!parser.parseOptionalComma()) {
1200 if (parser.parseAttribute(map)) {
1201 parser.emitError(parser.getNameLoc(), "expecting affine map");
1202 return {};
1203 }
1204 }
1205 if (parser.parseGreater())
1206 return {};
1207 return SequenceType::get(parser.getContext(), shape, eleTy, map);
1208}
1209
1210void fir::SequenceType::print(mlir::AsmPrinter &printer) const {
1211 auto shape = getShape();
1212 if (shape.size()) {
1213 printer << '<';
1214 for (const auto &b : shape) {
1215 if (b >= 0)
1216 printer << b << 'x';
1217 else
1218 printer << "?x";
1219 }
1220 } else {
1221 printer << "<*:";
1222 }
1223 printer << getEleTy();
1224 if (auto map = getLayoutMap()) {
1225 printer << ", ";
1226 map.print(printer.getStream());
1227 }
1228 printer << '>';
1229}
1230
1231unsigned fir::SequenceType::getConstantRows() const {
1232 if (hasDynamicSize(getEleTy()))
1233 return 0;
1234 auto shape = getShape();
1235 unsigned count = 0;
1236 for (auto d : shape) {
1237 if (d == getUnknownExtent())
1238 break;
1239 ++count;
1240 }
1241 return count;
1242}
1243
1244llvm::LogicalResult fir::SequenceType::verify(
1245 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1246 llvm::ArrayRef<int64_t> shape, mlir::Type eleTy,
1247 mlir::AffineMapAttr layoutMap) {
1248 // DIMENSION attribute can only be applied to an intrinsic or record type
1249 if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1250 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType,
1251 ReferenceType, TypeDescType, SequenceType>(eleTy))
1252 return emitError() << "cannot build an array of this element type: "
1253 << eleTy << '\n';
1254 return mlir::success();
1255}
1256
1257//===----------------------------------------------------------------------===//
1258// ShapeType
1259//===----------------------------------------------------------------------===//
1260
1261mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) {
1262 return parseRankSingleton<fir::ShapeType>(parser);
1263}
1264
1265void fir::ShapeType::print(mlir::AsmPrinter &printer) const {
1266 printer << "<" << getImpl()->rank << ">";
1267}
1268
1269//===----------------------------------------------------------------------===//
1270// ShapeShiftType
1271//===----------------------------------------------------------------------===//
1272
1273mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) {
1274 return parseRankSingleton<fir::ShapeShiftType>(parser);
1275}
1276
1277void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const {
1278 printer << "<" << getRank() << ">";
1279}
1280
1281//===----------------------------------------------------------------------===//
1282// ShiftType
1283//===----------------------------------------------------------------------===//
1284
1285mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) {
1286 return parseRankSingleton<fir::ShiftType>(parser);
1287}
1288
1289void fir::ShiftType::print(mlir::AsmPrinter &printer) const {
1290 printer << "<" << getRank() << ">";
1291}
1292
1293//===----------------------------------------------------------------------===//
1294// SliceType
1295//===----------------------------------------------------------------------===//
1296
1297// `slice` `<` rank `>`
1298mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) {
1299 return parseRankSingleton<fir::SliceType>(parser);
1300}
1301
1302void fir::SliceType::print(mlir::AsmPrinter &printer) const {
1303 printer << "<" << getRank() << '>';
1304}
1305
1306//===----------------------------------------------------------------------===//
1307// TypeDescType
1308//===----------------------------------------------------------------------===//
1309
1310// `tdesc` `<` type `>`
1311mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) {
1312 return parseTypeSingleton<fir::TypeDescType>(parser);
1313}
1314
1315void fir::TypeDescType::print(mlir::AsmPrinter &printer) const {
1316 printer << "<" << getOfTy() << '>';
1317}
1318
1319llvm::LogicalResult fir::TypeDescType::verify(
1320 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1321 mlir::Type eleTy) {
1322 if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1323 ShiftType, SliceType, FieldType, LenType, ReferenceType,
1324 TypeDescType>(eleTy))
1325 return emitError() << "cannot build a type descriptor of type: " << eleTy
1326 << '\n';
1327 return mlir::success();
1328}
1329
1330//===----------------------------------------------------------------------===//
1331// VectorType
1332//===----------------------------------------------------------------------===//
1333
1334// `vector` `<` len `:` type `>`
1335mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) {
1336 int64_t len = 0;
1337 mlir::Type eleTy;
1338 if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
1339 parser.parseType(eleTy) || parser.parseGreater())
1340 return {};
1341 return fir::VectorType::get(len, eleTy);
1342}
1343
1344void fir::VectorType::print(mlir::AsmPrinter &printer) const {
1345 printer << "<" << getLen() << ':' << getEleTy() << '>';
1346}
1347
1348llvm::LogicalResult fir::VectorType::verify(
1349 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
1350 mlir::Type eleTy) {
1351 if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
1352 return emitError() << "cannot build a vector of type " << eleTy << '\n';
1353 return mlir::success();
1354}
1355
1356bool fir::VectorType::isValidElementType(mlir::Type t) {
1357 return isa_real(t) || isa_integer(t);
1358}
1359
1360bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
1361 mlir::TupleType tuple = mlir::dyn_cast<mlir::TupleType>(ty);
1362 return tuple && tuple.size() == 2 &&
1363 (mlir::isa<fir::BoxProcType>(tuple.getType(0)) ||
1364 (acceptRawFunc && mlir::isa<mlir::FunctionType>(tuple.getType(0)))) &&
1365 fir::isa_integer(tuple.getType(1));
1366}
1367
1368bool fir::hasAbstractResult(mlir::FunctionType ty) {
1369 if (ty.getNumResults() == 0)
1370 return false;
1371 auto resultType = ty.getResult(0);
1372 return mlir::isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>(
1373 resultType);
1374}
1375
1376/// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
1377/// messages only.
1378mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context,
1379 llvm::Type::TypeID typeID, fir::KindTy kind) {
1380 switch (typeID) {
1381 case llvm::Type::TypeID::HalfTyID:
1382 return mlir::Float16Type::get(context);
1383 case llvm::Type::TypeID::BFloatTyID:
1384 return mlir::BFloat16Type::get(context);
1385 case llvm::Type::TypeID::FloatTyID:
1386 return mlir::Float32Type::get(context);
1387 case llvm::Type::TypeID::DoubleTyID:
1388 return mlir::Float64Type::get(context);
1389 case llvm::Type::TypeID::X86_FP80TyID:
1390 return mlir::Float80Type::get(context);
1391 case llvm::Type::TypeID::FP128TyID:
1392 return mlir::Float128Type::get(context);
1393 default:
1394 mlir::emitError(mlir::UnknownLoc::get(context))
1395 << "unsupported type: !fir.real<" << kind << ">";
1396 return {};
1397 }
1398}
1399
1400//===----------------------------------------------------------------------===//
1401// BaseBoxType
1402//===----------------------------------------------------------------------===//
1403
1404mlir::Type BaseBoxType::getEleTy() const {
1405 return llvm::TypeSwitch<fir::BaseBoxType, mlir::Type>(*this)
1406 .Case<fir::BoxType, fir::ClassType>(
1407 [](auto type) { return type.getEleTy(); });
1408}
1409
1410mlir::Type BaseBoxType::getBaseAddressType() const {
1411 mlir::Type eleTy = getEleTy();
1412 if (fir::isa_ref_type(eleTy))
1413 return eleTy;
1414 return fir::ReferenceType::get(eleTy, isVolatile());
1415}
1416
1417mlir::Type BaseBoxType::unwrapInnerType() const {
1418 return fir::unwrapInnerType(getEleTy());
1419}
1420
/// Rebuilds \p type with its array shape replaced by \p newShape, recursing
/// through fir.ref, fir.box, fir.class, fir.ptr and fir.heap wrappers.
///
/// - If \p newShape is set, the innermost sequence type gets that shape, or a
///   leaf type is wrapped into a sequence of that shape.
/// - If \p newShape is empty, an innermost sequence type is replaced by its
///   element type and leaf types are returned unchanged.
/// Only the shape and the element type are passed to SequenceType::get, so a
/// layout map on the original sequence type is not carried over.
static mlir::Type
changeTypeShape(mlir::Type type,
                std::optional<fir::SequenceType::ShapeRef> newShape) {
  return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
      .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
        if (newShape)
          return fir::SequenceType::get(*newShape, seqTy.getEleTy());
        return seqTy.getEleTy();
      })
      // Wrappers with a volatility flag: rebuild around the reshaped element
      // type and keep the flag.
      .Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(
          [&](auto t) -> mlir::Type {
            using FIRT = decltype(t);
            return FIRT::get(changeTypeShape(t.getEleTy(), newShape),
                             t.isVolatile());
          })
      // Wrappers without a volatility flag.
      .Case<fir::PointerType, fir::HeapType>([&](auto t) -> mlir::Type {
        using FIRT = decltype(t);
        return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
      })
      // Anything else must be a scalar leaf (checked by the assert in debug
      // builds).
      .Default([&](mlir::Type t) -> mlir::Type {
        assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
                llvm::isa<mlir::NoneType>(t) ||
                llvm::isa<fir::CharacterType>(t)) &&
               "unexpected FIR leaf type");
        if (newShape)
          return fir::SequenceType::get(*newShape, t);
        return t;
      });
}
1450
1451fir::BaseBoxType
1452fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const {
1453 fir::SequenceType seqTy = fir::unwrapUntilSeqType(shapeMold);
1454 std::optional<fir::SequenceType::ShapeRef> newShape;
1455 if (seqTy)
1456 newShape = seqTy.getShape();
1457 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1458}
1459
1460fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewShape(int rank) const {
1461 std::optional<fir::SequenceType::ShapeRef> newShape;
1462 fir::SequenceType::Shape shapeVector;
1463 if (rank > 0) {
1464 shapeVector =
1465 fir::SequenceType::Shape(rank, fir::SequenceType::getUnknownExtent());
1466 newShape = shapeVector;
1467 }
1468 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1469}
1470
1471fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewAttr(
1472 fir::BaseBoxType::Attribute attr) const {
1473 mlir::Type baseType = fir::unwrapRefType(getEleTy());
1474 switch (attr) {
1475 case fir::BaseBoxType::Attribute::None:
1476 break;
1477 case fir::BaseBoxType::Attribute::Allocatable:
1478 baseType = fir::HeapType::get(baseType);
1479 break;
1480 case fir::BaseBoxType::Attribute::Pointer:
1481 baseType = fir::PointerType::get(baseType);
1482 break;
1483 }
1484 return llvm::TypeSwitch<fir::BaseBoxType, fir::BaseBoxType>(*this)
1485 .Case<fir::BoxType>([baseType](auto b) {
1486 return fir::BoxType::get(baseType, b.isVolatile());
1487 })
1488 .Case<fir::ClassType>([baseType](auto b) {
1489 return fir::ClassType::get(baseType, b.isVolatile());
1490 });
1491}
1492
1493bool fir::BaseBoxType::isAssumedRank() const {
1494 if (auto seqTy =
1495 mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy())))
1496 return seqTy.hasUnknownShape();
1497 return false;
1498}
1499
1500bool fir::BaseBoxType::isPointer() const {
1501 return llvm::isa<fir::PointerType>(getEleTy());
1502}
1503
1504bool fir::BaseBoxType::isPointerOrAllocatable() const {
1505 return llvm::isa<fir::PointerType, fir::HeapType>(getEleTy());
1506}
1507
1508bool BaseBoxType::isVolatile() const { return fir::isa_volatile_type(*this); }
1509
1510//===----------------------------------------------------------------------===//
1511// FIROpsDialect
1512//===----------------------------------------------------------------------===//
1513
1514void FIROpsDialect::registerTypes() {
1515 addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, ClassType,
1516 FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
1517 LLVMPointerType, PointerType, RecordType, ReferenceType,
1518 SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType,
1519 TypeDescType, fir::VectorType, fir::DummyScopeType>();
1520 fir::ReferenceType::attachInterface<
1521 OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
1522 fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
1523 *getContext());
1524 fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
1525 *getContext());
1526 fir::LLVMPointerType::attachInterface<
1527 OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1528}
1529
1530std::optional<std::pair<uint64_t, unsigned short>>
1531fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
1532 const mlir::DataLayout &dl,
1533 const fir::KindMapping &kindMap) {
1534 if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
1535 llvm::TypeSize size = dl.getTypeSize(ty);
1536 unsigned short alignment = dl.getTypeABIAlignment(ty);
1537 return std::pair{size, alignment};
1538 }
1539 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1540 auto result = getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
1541 if (!result)
1542 return result;
1543 auto [eleSize, eleAlign] = *result;
1544 std::uint64_t size =
1545 llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
1546 return std::pair{size, eleAlign};
1547 }
1548 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
1549 std::uint64_t size = 0;
1550 unsigned short align = 1;
1551 for (auto component : recTy.getTypeList()) {
1552 auto result = getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
1553 if (!result)
1554 return result;
1555 auto [compSize, compAlign] = *result;
1556 size =
1557 llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
1558 align = std::max(align, compAlign);
1559 }
1560 return std::pair{size, align};
1561 }
1562 if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
1563 mlir::Type intTy = mlir::IntegerType::get(
1564 logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
1565 return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1566 }
1567 if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
1568 mlir::Type intTy = mlir::IntegerType::get(
1569 character.getContext(),
1570 kindMap.getCharacterBitsize(character.getFKind()));
1571 auto result = getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1572 if (!result)
1573 return result;
1574 auto [compSize, compAlign] = *result;
1575 if (character.hasConstantLen())
1576 compSize *= character.getLen();
1577 return std::pair{compSize, compAlign};
1578 }
1579 return std::nullopt;
1580}
1581
1582std::pair<std::uint64_t, unsigned short>
1583fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty,
1584 const mlir::DataLayout &dl,
1585 const fir::KindMapping &kindMap) {
1586 std::optional<std::pair<uint64_t, unsigned short>> result =
1587 getTypeSizeAndAlignment(loc, ty, dl, kindMap);
1588 if (result)
1589 return *result;
1590 TODO(loc, "computing size of a component");
1591}
1592

source code of flang/lib/Optimizer/Dialect/FIRType.cpp