1//===-- FIRType.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Dialect/FIRType.h"
14#include "flang/ISO_Fortran_binding_wrapper.h"
15#include "flang/Optimizer/Builder/Todo.h"
16#include "flang/Optimizer/Dialect/FIRDialect.h"
17#include "flang/Optimizer/Dialect/Support/KindMapping.h"
18#include "flang/Tools/PointerModels.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinDialect.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringSet.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/ErrorHandling.h"
27
28#define GET_TYPEDEF_CLASSES
29#include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
30
31using namespace fir;
32
33namespace {
34
35template <typename TYPE>
36TYPE parseIntSingleton(mlir::AsmParser &parser) {
37 int kind = 0;
38 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater())
39 return {};
40 return TYPE::get(parser.getContext(), kind);
41}
42
43template <typename TYPE>
44TYPE parseKindSingleton(mlir::AsmParser &parser) {
45 return parseIntSingleton<TYPE>(parser);
46}
47
48template <typename TYPE>
49TYPE parseRankSingleton(mlir::AsmParser &parser) {
50 return parseIntSingleton<TYPE>(parser);
51}
52
53template <typename TYPE>
54TYPE parseTypeSingleton(mlir::AsmParser &parser) {
55 mlir::Type ty;
56 if (parser.parseLess() || parser.parseType(result&: ty) || parser.parseGreater())
57 return {};
58 return TYPE::get(ty);
59}
60
61/// Is `ty` a standard or FIR integer type?
62static bool isaIntegerType(mlir::Type ty) {
63 // TODO: why aren't we using isa_integer? investigatation required.
64 return ty.isa<mlir::IntegerType>() || ty.isa<fir::IntegerType>();
65}
66
67bool verifyRecordMemberType(mlir::Type ty) {
68 return !(ty.isa<BoxCharType>() || ty.isa<ShapeType>() ||
69 ty.isa<ShapeShiftType>() || ty.isa<ShiftType>() ||
70 ty.isa<SliceType>() || ty.isa<FieldType>() || ty.isa<LenType>() ||
71 ty.isa<ReferenceType>() || ty.isa<TypeDescType>());
72}
73
74bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,
75 llvm::ArrayRef<RecordType::TypePair> a2) {
76 // FIXME: do we need to allow for any variance here?
77 return a1 == a2;
78}
79
80RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
81 llvm::ArrayRef<RecordType::TypePair> lenPList,
82 llvm::ArrayRef<RecordType::TypePair> typeList) {
83 auto loc = parser.getNameLoc();
84 if (!verifySameLists(derivedTy.getLenParamList(), lenPList) ||
85 !verifySameLists(derivedTy.getTypeList(), typeList)) {
86 parser.emitError(loc, message: "cannot redefine record type members");
87 return {};
88 }
89 for (auto &p : lenPList)
90 if (!isaIntegerType(p.second)) {
91 parser.emitError(loc, "LEN parameter must be integral type");
92 return {};
93 }
94 for (auto &p : typeList)
95 if (!verifyRecordMemberType(p.second)) {
96 parser.emitError(loc, "field parameter has invalid type");
97 return {};
98 }
99 llvm::StringSet<> uniq;
100 for (auto &p : lenPList)
101 if (!uniq.insert(p.first).second) {
102 parser.emitError(loc, "LEN parameter cannot have duplicate name");
103 return {};
104 }
105 for (auto &p : typeList)
106 if (!uniq.insert(p.first).second) {
107 parser.emitError(loc, "field cannot have duplicate name");
108 return {};
109 }
110 return derivedTy;
111}
112
113} // namespace
114
115// Implementation of the thin interface from dialect to type parser
116
117mlir::Type fir::parseFirType(FIROpsDialect *dialect,
118 mlir::DialectAsmParser &parser) {
119 mlir::StringRef typeTag;
120 mlir::Type genType;
121 auto parseResult = generatedTypeParser(parser, &typeTag, genType);
122 if (parseResult.has_value())
123 return genType;
124 parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
125 return {};
126}
127
128namespace fir {
129namespace detail {
130
131// Type storage classes
132
133/// Derived type storage
134struct RecordTypeStorage : public mlir::TypeStorage {
135 using KeyTy = llvm::StringRef;
136
137 static unsigned hashKey(const KeyTy &key) {
138 return llvm::hash_combine(args: key.str());
139 }
140
141 bool operator==(const KeyTy &key) const { return key == getName(); }
142
143 static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
144 const KeyTy &key) {
145 auto *storage = allocator.allocate<RecordTypeStorage>();
146 return new (storage) RecordTypeStorage{key};
147 }
148
149 llvm::StringRef getName() const { return name; }
150
151 void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) {
152 lens = list;
153 }
154 llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; }
155
156 void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; }
157 llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; }
158
159 bool isFinalized() const { return finalized; }
160 void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList,
161 llvm::ArrayRef<RecordType::TypePair> typeList) {
162 if (finalized)
163 return;
164 finalized = true;
165 setLenParamList(lenParamList);
166 setTypeList(typeList);
167 }
168
169protected:
170 std::string name;
171 bool finalized;
172 std::vector<RecordType::TypePair> lens;
173 std::vector<RecordType::TypePair> types;
174
175private:
176 RecordTypeStorage() = delete;
177 explicit RecordTypeStorage(llvm::StringRef name)
178 : name{name}, finalized{false} {}
179};
180
181} // namespace detail
182
/// Return true if `v` lies in the half-open interval [`lb`, `ub`).
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
  const bool atOrAboveLower = v >= lb;
  return atOrAboveLower && v < ub;
}
187
188bool isa_fir_type(mlir::Type t) {
189 return llvm::isa<FIROpsDialect>(t.getDialect());
190}
191
192bool isa_std_type(mlir::Type t) {
193 return llvm::isa<mlir::BuiltinDialect>(Val: t.getDialect());
194}
195
196bool isa_fir_or_std_type(mlir::Type t) {
197 if (auto funcType = t.dyn_cast<mlir::FunctionType>())
198 return llvm::all_of(funcType.getInputs(), isa_fir_or_std_type) &&
199 llvm::all_of(funcType.getResults(), isa_fir_or_std_type);
200 return isa_fir_type(t) || isa_std_type(t);
201}
202
203mlir::Type getDerivedType(mlir::Type ty) {
204 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
205 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto p) {
206 if (auto seq = p.getEleTy().template dyn_cast<fir::SequenceType>())
207 return seq.getEleTy();
208 return p.getEleTy();
209 })
210 .Default([](mlir::Type t) { return t; });
211}
212
213mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
214 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
215 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
216 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
217 .Default([](mlir::Type) { return mlir::Type{}; });
218}
219
220mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
221 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
222 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
223 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
224 .Case<fir::BaseBoxType>(
225 [](auto p) { return unwrapRefType(p.getEleTy()); })
226 .Default([](mlir::Type) { return mlir::Type{}; });
227}
228
229static bool hasDynamicSize(fir::RecordType recTy) {
230 for (auto field : recTy.getTypeList()) {
231 if (auto arr = field.second.dyn_cast<fir::SequenceType>()) {
232 if (sequenceWithNonConstantShape(arr))
233 return true;
234 } else if (characterWithDynamicLen(field.second)) {
235 return true;
236 } else if (auto rec = field.second.dyn_cast<fir::RecordType>()) {
237 if (hasDynamicSize(rec))
238 return true;
239 }
240 }
241 return false;
242}
243
244bool hasDynamicSize(mlir::Type t) {
245 if (auto arr = t.dyn_cast<fir::SequenceType>()) {
246 if (sequenceWithNonConstantShape(arr))
247 return true;
248 t = arr.getEleTy();
249 }
250 if (characterWithDynamicLen(t))
251 return true;
252 if (auto rec = t.dyn_cast<fir::RecordType>())
253 return hasDynamicSize(rec);
254 return false;
255}
256
257mlir::Type extractSequenceType(mlir::Type ty) {
258 if (mlir::isa<fir::SequenceType>(ty))
259 return ty;
260 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
261 return extractSequenceType(boxTy.getEleTy());
262 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
263 return extractSequenceType(heapTy.getEleTy());
264 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
265 return extractSequenceType(ptrTy.getEleTy());
266 return mlir::Type{};
267}
268
269bool isPointerType(mlir::Type ty) {
270 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
271 ty = refTy;
272 if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>())
273 return boxTy.getEleTy().isa<fir::PointerType>();
274 return false;
275}
276
277bool isAllocatableType(mlir::Type ty) {
278 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
279 ty = refTy;
280 if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>())
281 return boxTy.getEleTy().isa<fir::HeapType>();
282 return false;
283}
284
285bool isBoxNone(mlir::Type ty) {
286 if (auto box = ty.dyn_cast<fir::BoxType>())
287 return box.getEleTy().isa<mlir::NoneType>();
288 return false;
289}
290
291bool isBoxedRecordType(mlir::Type ty) {
292 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
293 ty = refTy;
294 if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
295 if (boxTy.getEleTy().isa<fir::RecordType>())
296 return true;
297 mlir::Type innerType = boxTy.unwrapInnerType();
298 return innerType && innerType.isa<fir::RecordType>();
299 }
300 return false;
301}
302
303bool isScalarBoxedRecordType(mlir::Type ty) {
304 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
305 ty = refTy;
306 if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
307 if (boxTy.getEleTy().isa<fir::RecordType>())
308 return true;
309 if (auto heapTy = boxTy.getEleTy().dyn_cast<fir::HeapType>())
310 return heapTy.getEleTy().isa<fir::RecordType>();
311 if (auto ptrTy = boxTy.getEleTy().dyn_cast<fir::PointerType>())
312 return ptrTy.getEleTy().isa<fir::RecordType>();
313 }
314 return false;
315}
316
317bool isAssumedType(mlir::Type ty) {
318 // Rule out CLASS(*) which are `fir.class<[fir.array] none>`.
319 if (mlir::isa<fir::ClassType>(ty))
320 return false;
321 mlir::Type valueType = fir::unwrapPassByRefType(fir::unwrapRefType(ty));
322 // Refuse raw `none` or `fir.array<none>` since assumed type
323 // should be in memory variables.
324 if (valueType == ty)
325 return false;
326 mlir::Type inner = fir::unwrapSequenceType(valueType);
327 return mlir::isa<mlir::NoneType>(Val: inner);
328}
329
330bool isAssumedShape(mlir::Type ty) {
331 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
332 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
333 return seqTy.hasDynamicExtents();
334 return false;
335}
336
337bool isAllocatableOrPointerArray(mlir::Type ty) {
338 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
339 ty = refTy;
340 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
341 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
342 return mlir::isa<fir::SequenceType>(heapTy.getEleTy());
343 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
344 return mlir::isa<fir::SequenceType>(ptrTy.getEleTy());
345 }
346 return false;
347}
348
349bool isTypeWithDescriptor(mlir::Type ty) {
350 if (mlir::isa<fir::BaseBoxType>(unwrapRefType(ty)))
351 return true;
352 return false;
353}
354
355bool isPolymorphicType(mlir::Type ty) {
356 // CLASS(T) or CLASS(*)
357 if (mlir::isa<fir::ClassType>(fir::unwrapRefType(ty)))
358 return true;
359 // assumed type are polymorphic.
360 return isAssumedType(ty);
361}
362
363bool isUnlimitedPolymorphicType(mlir::Type ty) {
364 // CLASS(*)
365 if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) {
366 if (clTy.getEleTy().isa<mlir::NoneType>())
367 return true;
368 mlir::Type innerType = clTy.unwrapInnerType();
369 return innerType && innerType.isa<mlir::NoneType>();
370 }
371 // TYPE(*)
372 return isAssumedType(ty);
373}
374
375mlir::Type unwrapInnerType(mlir::Type ty) {
376 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
377 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto t) {
378 mlir::Type eleTy = t.getEleTy();
379 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
380 return seqTy.getEleTy();
381 return eleTy;
382 })
383 .Case<fir::RecordType>([](auto t) { return t; })
384 .Default([](mlir::Type) { return mlir::Type{}; });
385}
386
387bool isRecordWithAllocatableMember(mlir::Type ty) {
388 if (auto recTy = ty.dyn_cast<fir::RecordType>())
389 for (auto [field, memTy] : recTy.getTypeList()) {
390 if (fir::isAllocatableType(memTy))
391 return true;
392 // A record type cannot recursively include itself as a direct member.
393 // There must be an intervening `ptr` type, so recursion is safe here.
394 if (memTy.isa<fir::RecordType>() && isRecordWithAllocatableMember(memTy))
395 return true;
396 }
397 return false;
398}
399
400bool isRecordWithDescriptorMember(mlir::Type ty) {
401 ty = unwrapSequenceType(ty);
402 if (auto recTy = ty.dyn_cast<fir::RecordType>())
403 for (auto [field, memTy] : recTy.getTypeList()) {
404 if (mlir::isa<fir::BaseBoxType>(memTy))
405 return true;
406 if (memTy.isa<fir::RecordType>() && isRecordWithDescriptorMember(memTy))
407 return true;
408 }
409 return false;
410}
411
412mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) {
413 while (true) {
414 mlir::Type nt = unwrapSequenceType(unwrapRefType(ty));
415 if (auto vecTy = nt.dyn_cast<fir::VectorType>())
416 nt = vecTy.getEleTy();
417 if (nt == ty)
418 return ty;
419 ty = nt;
420 }
421}
422
423mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) {
424 if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
425 return seqTy.getEleTy();
426 if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
427 auto eleTy = unwrapRefType(boxTy.getEleTy());
428 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
429 return seqTy.getEleTy();
430 }
431 return ty;
432}
433
434unsigned getBoxRank(mlir::Type boxTy) {
435 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(t: boxTy);
436 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
437 return seqTy.getDimension();
438 return 0;
439}
440
441/// Return the ISO_C_BINDING intrinsic module value of type \p ty.
442int getTypeCode(mlir::Type ty, const fir::KindMapping &kindMap) {
443 unsigned width = 0;
444 if (mlir::IntegerType intTy = ty.dyn_cast<mlir::IntegerType>()) {
445 switch (intTy.getWidth()) {
446 case 8:
447 return CFI_type_int8_t;
448 case 16:
449 return CFI_type_int16_t;
450 case 32:
451 return CFI_type_int32_t;
452 case 64:
453 return CFI_type_int64_t;
454 case 128:
455 return CFI_type_int128_t;
456 }
457 llvm_unreachable("unsupported integer type");
458 }
459 if (fir::LogicalType logicalTy = ty.dyn_cast<fir::LogicalType>()) {
460 switch (kindMap.getLogicalBitsize(logicalTy.getFKind())) {
461 case 8:
462 return CFI_type_Bool;
463 case 16:
464 return CFI_type_int_least16_t;
465 case 32:
466 return CFI_type_int_least32_t;
467 case 64:
468 return CFI_type_int_least64_t;
469 }
470 llvm_unreachable("unsupported logical type");
471 }
472 if (mlir::FloatType floatTy = ty.dyn_cast<mlir::FloatType>()) {
473 switch (floatTy.getWidth()) {
474 case 16:
475 return floatTy.isBF16() ? CFI_type_bfloat : CFI_type_half_float;
476 case 32:
477 return CFI_type_float;
478 case 64:
479 return CFI_type_double;
480 case 80:
481 return CFI_type_extended_double;
482 case 128:
483 return CFI_type_float128;
484 }
485 llvm_unreachable("unsupported real type");
486 }
487 if (fir::isa_complex(ty)) {
488 if (mlir::ComplexType complexTy = ty.dyn_cast<mlir::ComplexType>()) {
489 mlir::FloatType floatTy =
490 complexTy.getElementType().cast<mlir::FloatType>();
491 if (floatTy.isBF16())
492 return CFI_type_bfloat_Complex;
493 width = floatTy.getWidth();
494 } else if (fir::ComplexType complexTy = ty.dyn_cast<fir::ComplexType>()) {
495 auto FKind = complexTy.getFKind();
496 if (FKind == 3)
497 return CFI_type_bfloat_Complex;
498 width = kindMap.getRealBitsize(FKind);
499 }
500 switch (width) {
501 case 16:
502 return CFI_type_half_float_Complex;
503 case 32:
504 return CFI_type_float_Complex;
505 case 64:
506 return CFI_type_double_Complex;
507 case 80:
508 return CFI_type_extended_double_Complex;
509 case 128:
510 return CFI_type_float128_Complex;
511 }
512 llvm_unreachable("unsupported complex size");
513 }
514 if (fir::CharacterType charTy = ty.dyn_cast<fir::CharacterType>()) {
515 switch (kindMap.getCharacterBitsize(charTy.getFKind())) {
516 case 8:
517 return CFI_type_char;
518 case 16:
519 return CFI_type_char16_t;
520 case 32:
521 return CFI_type_char32_t;
522 }
523 llvm_unreachable("unsupported character type");
524 }
525 if (fir::isa_ref_type(ty))
526 return CFI_type_cptr;
527 if (ty.isa<fir::RecordType>())
528 return CFI_type_struct;
529 llvm_unreachable("unsupported type");
530}
531
532std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap,
533 llvm::StringRef prefix) {
534 std::string buf;
535 llvm::raw_string_ostream name{buf};
536 name << prefix.str();
537 if (!prefix.empty())
538 name << "_";
539 while (ty) {
540 if (fir::isa_trivial(ty)) {
541 if (mlir::isa<mlir::IndexType>(Val: ty)) {
542 name << "idx";
543 } else if (ty.isIntOrIndex()) {
544 name << 'i' << ty.getIntOrFloatBitWidth();
545 } else if (ty.isa<mlir::FloatType>()) {
546 name << 'f' << ty.getIntOrFloatBitWidth();
547 } else if (fir::isa_complex(ty)) {
548 name << 'z';
549 if (auto cplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) {
550 auto floatTy = cplxTy.getElementType().cast<mlir::FloatType>();
551 name << floatTy.getWidth();
552 } else if (auto cplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
553 name << kindMap.getRealBitsize(cplxTy.getFKind());
554 }
555 } else if (auto logTy = mlir::dyn_cast_or_null<fir::LogicalType>(ty)) {
556 name << 'l' << kindMap.getLogicalBitsize(logTy.getFKind());
557 } else {
558 llvm::report_fatal_error(reason: "unsupported type");
559 }
560 break;
561 } else if (mlir::isa<mlir::NoneType>(Val: ty)) {
562 name << "none";
563 break;
564 } else if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(ty)) {
565 name << 'c' << kindMap.getCharacterBitsize(charTy.getFKind());
566 if (charTy.getLen() == fir::CharacterType::unknownLen())
567 name << "xU";
568 else if (charTy.getLen() != fir::CharacterType::singleton())
569 name << "x" << charTy.getLen();
570 break;
571 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
572 for (auto extent : seqTy.getShape()) {
573 if (extent == fir::SequenceType::getUnknownExtent())
574 name << "Ux";
575 else
576 name << extent << 'x';
577 }
578 ty = seqTy.getEleTy();
579 } else if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
580 name << "ref_";
581 ty = refTy.getEleTy();
582 } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::PointerType>(ty)) {
583 name << "ptr_";
584 ty = ptrTy.getEleTy();
585 } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::LLVMPointerType>(ty)) {
586 name << "llvmptr_";
587 ty = ptrTy.getEleTy();
588 } else if (auto heapTy = mlir::dyn_cast_or_null<fir::HeapType>(ty)) {
589 name << "heap_";
590 ty = heapTy.getEleTy();
591 } else if (auto classTy = mlir::dyn_cast_or_null<fir::ClassType>(ty)) {
592 name << "class_";
593 ty = classTy.getEleTy();
594 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BoxType>(ty)) {
595 name << "box_";
596 ty = boxTy.getEleTy();
597 } else if (auto boxcharTy = mlir::dyn_cast_or_null<fir::BoxCharType>(ty)) {
598 name << "boxchar_";
599 ty = boxcharTy.getEleTy();
600 } else if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(ty)) {
601 name << "rec_" << recTy.getName();
602 break;
603 } else {
604 llvm::report_fatal_error(reason: "unsupported type");
605 }
606 }
607 return name.str();
608}
609
610mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
611 bool turnBoxIntoClass) {
612 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
613 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
614 return fir::SequenceType::get(seqTy.getShape(), newElementType);
615 })
616 .Case<fir::PointerType, fir::HeapType, fir::ReferenceType,
617 fir::ClassType>([&](auto t) -> mlir::Type {
618 using FIRT = decltype(t);
619 return FIRT::get(
620 changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
621 })
622 .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
623 mlir::Type newInnerType =
624 changeElementType(t.getEleTy(), newElementType, false);
625 if (turnBoxIntoClass)
626 return fir::ClassType::get(newInnerType);
627 return fir::BoxType::get(newInnerType);
628 })
629 .Default([&](mlir::Type t) -> mlir::Type {
630 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
631 llvm::isa<mlir::NoneType>(t)) &&
632 "unexpected FIR leaf type");
633 return newElementType;
634 });
635}
636
637} // namespace fir
638
639namespace {
640
641static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4>
642 recordTypeVisited;
643
644} // namespace
645
646void fir::verifyIntegralType(mlir::Type type) {
647 if (isaIntegerType(ty: type) || type.isa<mlir::IndexType>())
648 return;
649 llvm::report_fatal_error(reason: "expected integral type");
650}
651
652void fir::printFirType(FIROpsDialect *, mlir::Type ty,
653 mlir::DialectAsmPrinter &p) {
654 if (mlir::failed(result: generatedTypePrinter(ty, p)))
655 llvm::report_fatal_error(reason: "unknown type to print");
656}
657
658bool fir::isa_unknown_size_box(mlir::Type t) {
659 if (auto boxTy = t.dyn_cast<fir::BaseBoxType>()) {
660 auto valueType = fir::unwrapPassByRefType(boxTy);
661 if (auto seqTy = valueType.dyn_cast<fir::SequenceType>())
662 if (seqTy.hasUnknownShape())
663 return true;
664 }
665 return false;
666}
667
668//===----------------------------------------------------------------------===//
669// BoxProcType
670//===----------------------------------------------------------------------===//
671
672// `boxproc` `<` return-type `>`
673mlir::Type BoxProcType::parse(mlir::AsmParser &parser) {
674 mlir::Type ty;
675 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
676 return {};
677 return get(parser.getContext(), ty);
678}
679
680void fir::BoxProcType::print(mlir::AsmPrinter &printer) const {
681 printer << "<" << getEleTy() << '>';
682}
683
684mlir::LogicalResult
685BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
686 mlir::Type eleTy) {
687 if (eleTy.isa<mlir::FunctionType>())
688 return mlir::success();
689 if (auto refTy = eleTy.dyn_cast<ReferenceType>())
690 if (refTy.isa<mlir::FunctionType>())
691 return mlir::success();
692 return emitError() << "invalid type for boxproc" << eleTy << '\n';
693}
694
695static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
696 return eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
697 SliceType, FieldType, LenType, HeapType, PointerType,
698 ReferenceType, TypeDescType>();
699}
700
701//===----------------------------------------------------------------------===//
702// BoxType
703//===----------------------------------------------------------------------===//
704
705mlir::LogicalResult
706fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
707 mlir::Type eleTy) {
708 if (eleTy.isa<fir::BaseBoxType>())
709 return emitError() << "invalid element type\n";
710 // TODO
711 return mlir::success();
712}
713
714//===----------------------------------------------------------------------===//
715// BoxCharType
716//===----------------------------------------------------------------------===//
717
718mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) {
719 return parseKindSingleton<fir::BoxCharType>(parser);
720}
721
722void fir::BoxCharType::print(mlir::AsmPrinter &printer) const {
723 printer << "<" << getKind() << ">";
724}
725
726CharacterType
727fir::BoxCharType::getElementType(mlir::MLIRContext *context) const {
728 return CharacterType::getUnknownLen(context, getKind());
729}
730
731CharacterType fir::BoxCharType::getEleTy() const {
732 return getElementType(getContext());
733}
734
735//===----------------------------------------------------------------------===//
736// CharacterType
737//===----------------------------------------------------------------------===//
738
739// `char` `<` kind [`,` `len`] `>`
740mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) {
741 int kind = 0;
742 if (parser.parseLess() || parser.parseInteger(kind))
743 return {};
744 CharacterType::LenType len = 1;
745 if (mlir::succeeded(parser.parseOptionalComma())) {
746 if (mlir::succeeded(parser.parseOptionalQuestion())) {
747 len = fir::CharacterType::unknownLen();
748 } else if (!mlir::succeeded(parser.parseInteger(len))) {
749 return {};
750 }
751 }
752 if (parser.parseGreater())
753 return {};
754 return get(parser.getContext(), kind, len);
755}
756
757void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
758 printer << "<" << getFKind();
759 auto len = getLen();
760 if (len != fir::CharacterType::singleton()) {
761 printer << ',';
762 if (len == fir::CharacterType::unknownLen())
763 printer << '?';
764 else
765 printer << len;
766 }
767 printer << '>';
768}
769
770//===----------------------------------------------------------------------===//
771// ClassType
772//===----------------------------------------------------------------------===//
773
774mlir::LogicalResult
775fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
776 mlir::Type eleTy) {
777 if (eleTy.isa<fir::RecordType, fir::SequenceType, fir::HeapType,
778 fir::PointerType, mlir::NoneType, mlir::IntegerType,
779 mlir::FloatType, fir::CharacterType, fir::LogicalType,
780 fir::ComplexType, mlir::ComplexType>())
781 return mlir::success();
782 return emitError() << "invalid element type\n";
783}
784
785//===----------------------------------------------------------------------===//
786// ComplexType
787//===----------------------------------------------------------------------===//
788
789mlir::Type fir::ComplexType::parse(mlir::AsmParser &parser) {
790 return parseKindSingleton<fir::ComplexType>(parser);
791}
792
793void fir::ComplexType::print(mlir::AsmPrinter &printer) const {
794 printer << "<" << getFKind() << '>';
795}
796
797mlir::Type fir::ComplexType::getElementType() const {
798 return fir::RealType::get(getContext(), getFKind());
799}
800
801// Return the MLIR float type of the complex element type.
802mlir::Type fir::ComplexType::getEleType(const fir::KindMapping &kindMap) const {
803 auto fkind = getFKind();
804 auto realTypeID = kindMap.getRealTypeID(fkind);
805 return fir::fromRealTypeID(getContext(), realTypeID, fkind);
806}
807
808//===----------------------------------------------------------------------===//
809// HeapType
810//===----------------------------------------------------------------------===//
811
812// `heap` `<` type `>`
813mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) {
814 return parseTypeSingleton<HeapType>(parser);
815}
816
817void fir::HeapType::print(mlir::AsmPrinter &printer) const {
818 printer << "<" << getEleTy() << '>';
819}
820
821mlir::LogicalResult
822fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
823 mlir::Type eleTy) {
824 if (cannotBePointerOrHeapElementType(eleTy))
825 return emitError() << "cannot build a heap pointer to type: " << eleTy
826 << '\n';
827 return mlir::success();
828}
829
830//===----------------------------------------------------------------------===//
831// IntegerType
832//===----------------------------------------------------------------------===//
833
834// `int` `<` kind `>`
835mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) {
836 return parseKindSingleton<fir::IntegerType>(parser);
837}
838
839void fir::IntegerType::print(mlir::AsmPrinter &printer) const {
840 printer << "<" << getFKind() << '>';
841}
842
843//===----------------------------------------------------------------------===//
844// LogicalType
845//===----------------------------------------------------------------------===//
846
847// `logical` `<` kind `>`
848mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) {
849 return parseKindSingleton<fir::LogicalType>(parser);
850}
851
852void fir::LogicalType::print(mlir::AsmPrinter &printer) const {
853 printer << "<" << getFKind() << '>';
854}
855
856//===----------------------------------------------------------------------===//
857// PointerType
858//===----------------------------------------------------------------------===//
859
860// `ptr` `<` type `>`
861mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) {
862 return parseTypeSingleton<fir::PointerType>(parser);
863}
864
865void fir::PointerType::print(mlir::AsmPrinter &printer) const {
866 printer << "<" << getEleTy() << '>';
867}
868
869mlir::LogicalResult fir::PointerType::verify(
870 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
871 mlir::Type eleTy) {
872 if (cannotBePointerOrHeapElementType(eleTy))
873 return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
874 return mlir::success();
875}
876
877//===----------------------------------------------------------------------===//
878// RealType
879//===----------------------------------------------------------------------===//
880
881// `real` `<` kind `>`
882mlir::Type fir::RealType::parse(mlir::AsmParser &parser) {
883 return parseKindSingleton<fir::RealType>(parser);
884}
885
886void fir::RealType::print(mlir::AsmPrinter &printer) const {
887 printer << "<" << getFKind() << '>';
888}
889
890mlir::LogicalResult
891fir::RealType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
892 KindTy fKind) {
893 // TODO
894 return mlir::success();
895}
896
897mlir::Type fir::RealType::getFloatType(const fir::KindMapping &kindMap) const {
898 auto fkind = getFKind();
899 auto realTypeID = kindMap.getRealTypeID(fkind);
900 return fir::fromRealTypeID(getContext(), realTypeID, fkind);
901}
902
903//===----------------------------------------------------------------------===//
904// RecordType
905//===----------------------------------------------------------------------===//
906
907// Fortran derived type
908// `type` `<` name
909// (`(` id `:` type (`,` id `:` type)* `)`)?
910// (`{` id `:` type (`,` id `:` type)* `}`)? '>'
911mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
912 llvm::StringRef name;
913 if (parser.parseLess() || parser.parseKeyword(&name))
914 return {};
915 RecordType result = RecordType::get(parser.getContext(), name);
916
917 RecordType::TypeList lenParamList;
918 if (!parser.parseOptionalLParen()) {
919 while (true) {
920 llvm::StringRef lenparam;
921 mlir::Type intTy;
922 if (parser.parseKeyword(&lenparam) || parser.parseColon() ||
923 parser.parseType(intTy)) {
924 parser.emitError(parser.getNameLoc(), "expected LEN parameter list");
925 return {};
926 }
927 lenParamList.emplace_back(lenparam, intTy);
928 if (parser.parseOptionalComma())
929 break;
930 }
931 if (parser.parseRParen())
932 return {};
933 }
934
935 RecordType::TypeList typeList;
936 if (!parser.parseOptionalLBrace()) {
937 while (true) {
938 llvm::StringRef field;
939 mlir::Type fldTy;
940 if (parser.parseKeyword(&field) || parser.parseColon() ||
941 parser.parseType(fldTy)) {
942 parser.emitError(parser.getNameLoc(), "expected field type list");
943 return {};
944 }
945 typeList.emplace_back(field, fldTy);
946 if (parser.parseOptionalComma())
947 break;
948 }
949 if (parser.parseRBrace())
950 return {};
951 }
952
953 if (parser.parseGreater())
954 return {};
955
956 if (lenParamList.empty() && typeList.empty())
957 return result;
958
959 result.finalize(lenParamList, typeList);
960 return verifyDerived(parser, result, lenParamList, typeList);
961}
962
963void fir::RecordType::print(mlir::AsmPrinter &printer) const {
964 printer << "<" << getName();
965 if (!recordTypeVisited.count(uniqueKey())) {
966 recordTypeVisited.insert(uniqueKey());
967 if (getLenParamList().size()) {
968 char ch = '(';
969 for (auto p : getLenParamList()) {
970 printer << ch << p.first << ':';
971 p.second.print(printer.getStream());
972 ch = ',';
973 }
974 printer << ')';
975 }
976 if (getTypeList().size()) {
977 char ch = '{';
978 for (auto p : getTypeList()) {
979 printer << ch << p.first << ':';
980 p.second.print(printer.getStream());
981 ch = ',';
982 }
983 printer << '}';
984 }
985 recordTypeVisited.erase(uniqueKey());
986 }
987 printer << '>';
988}
989
990void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
991 llvm::ArrayRef<TypePair> typeList) {
992 getImpl()->finalize(lenPList, typeList);
993}
994
995llvm::StringRef fir::RecordType::getName() const {
996 return getImpl()->getName();
997}
998
999RecordType::TypeList fir::RecordType::getTypeList() const {
1000 return getImpl()->getTypeList();
1001}
1002
1003RecordType::TypeList fir::RecordType::getLenParamList() const {
1004 return getImpl()->getLenParamList();
1005}
1006
1007bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
1008
1009detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
1010 return getImpl();
1011}
1012
1013mlir::LogicalResult fir::RecordType::verify(
1014 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1015 llvm::StringRef name) {
1016 if (name.size() == 0)
1017 return emitError() << "record types must have a name";
1018 return mlir::success();
1019}
1020
1021mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
1022 for (auto f : getTypeList())
1023 if (ident == f.first)
1024 return f.second;
1025 return {};
1026}
1027
1028unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
1029 for (auto f : llvm::enumerate(getTypeList()))
1030 if (ident == f.value().first)
1031 return f.index();
1032 return std::numeric_limits<unsigned>::max();
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// ReferenceType
1037//===----------------------------------------------------------------------===//
1038
1039// `ref` `<` type `>`
1040mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
1041 return parseTypeSingleton<fir::ReferenceType>(parser);
1042}
1043
1044void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
1045 printer << "<" << getEleTy() << '>';
1046}
1047
1048mlir::LogicalResult fir::ReferenceType::verify(
1049 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1050 mlir::Type eleTy) {
1051 if (eleTy.isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
1052 ReferenceType, TypeDescType>())
1053 return emitError() << "cannot build a reference to type: " << eleTy << '\n';
1054 return mlir::success();
1055}
1056
1057//===----------------------------------------------------------------------===//
1058// SequenceType
1059//===----------------------------------------------------------------------===//
1060
1061// `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>`
1062// bounds ::= `?` | int-lit
1063mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) {
1064 if (parser.parseLess())
1065 return {};
1066 SequenceType::Shape shape;
1067 if (parser.parseOptionalStar()) {
1068 if (parser.parseDimensionList(shape, /*allowDynamic=*/true))
1069 return {};
1070 } else if (parser.parseColon()) {
1071 return {};
1072 }
1073 mlir::Type eleTy;
1074 if (parser.parseType(eleTy))
1075 return {};
1076 mlir::AffineMapAttr map;
1077 if (!parser.parseOptionalComma()) {
1078 if (parser.parseAttribute(map)) {
1079 parser.emitError(parser.getNameLoc(), "expecting affine map");
1080 return {};
1081 }
1082 }
1083 if (parser.parseGreater())
1084 return {};
1085 return SequenceType::get(parser.getContext(), shape, eleTy, map);
1086}
1087
1088void fir::SequenceType::print(mlir::AsmPrinter &printer) const {
1089 auto shape = getShape();
1090 if (shape.size()) {
1091 printer << '<';
1092 for (const auto &b : shape) {
1093 if (b >= 0)
1094 printer << b << 'x';
1095 else
1096 printer << "?x";
1097 }
1098 } else {
1099 printer << "<*:";
1100 }
1101 printer << getEleTy();
1102 if (auto map = getLayoutMap()) {
1103 printer << ", ";
1104 map.print(printer.getStream());
1105 }
1106 printer << '>';
1107}
1108
1109unsigned fir::SequenceType::getConstantRows() const {
1110 if (hasDynamicSize(getEleTy()))
1111 return 0;
1112 auto shape = getShape();
1113 unsigned count = 0;
1114 for (auto d : shape) {
1115 if (d == getUnknownExtent())
1116 break;
1117 ++count;
1118 }
1119 return count;
1120}
1121
1122mlir::LogicalResult fir::SequenceType::verify(
1123 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1124 llvm::ArrayRef<int64_t> shape, mlir::Type eleTy,
1125 mlir::AffineMapAttr layoutMap) {
1126 // DIMENSION attribute can only be applied to an intrinsic or record type
1127 if (eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1128 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType,
1129 ReferenceType, TypeDescType, SequenceType>())
1130 return emitError() << "cannot build an array of this element type: "
1131 << eleTy << '\n';
1132 return mlir::success();
1133}
1134
1135//===----------------------------------------------------------------------===//
1136// ShapeType
1137//===----------------------------------------------------------------------===//
1138
1139mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) {
1140 return parseRankSingleton<fir::ShapeType>(parser);
1141}
1142
1143void fir::ShapeType::print(mlir::AsmPrinter &printer) const {
1144 printer << "<" << getImpl()->rank << ">";
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// ShapeShiftType
1149//===----------------------------------------------------------------------===//
1150
1151mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) {
1152 return parseRankSingleton<fir::ShapeShiftType>(parser);
1153}
1154
1155void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const {
1156 printer << "<" << getRank() << ">";
1157}
1158
1159//===----------------------------------------------------------------------===//
1160// ShiftType
1161//===----------------------------------------------------------------------===//
1162
1163mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) {
1164 return parseRankSingleton<fir::ShiftType>(parser);
1165}
1166
1167void fir::ShiftType::print(mlir::AsmPrinter &printer) const {
1168 printer << "<" << getRank() << ">";
1169}
1170
1171//===----------------------------------------------------------------------===//
1172// SliceType
1173//===----------------------------------------------------------------------===//
1174
1175// `slice` `<` rank `>`
1176mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) {
1177 return parseRankSingleton<fir::SliceType>(parser);
1178}
1179
1180void fir::SliceType::print(mlir::AsmPrinter &printer) const {
1181 printer << "<" << getRank() << '>';
1182}
1183
1184//===----------------------------------------------------------------------===//
1185// TypeDescType
1186//===----------------------------------------------------------------------===//
1187
1188// `tdesc` `<` type `>`
1189mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) {
1190 return parseTypeSingleton<fir::TypeDescType>(parser);
1191}
1192
1193void fir::TypeDescType::print(mlir::AsmPrinter &printer) const {
1194 printer << "<" << getOfTy() << '>';
1195}
1196
1197mlir::LogicalResult fir::TypeDescType::verify(
1198 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1199 mlir::Type eleTy) {
1200 if (eleTy.isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1201 ShiftType, SliceType, FieldType, LenType, ReferenceType,
1202 TypeDescType>())
1203 return emitError() << "cannot build a type descriptor of type: " << eleTy
1204 << '\n';
1205 return mlir::success();
1206}
1207
1208//===----------------------------------------------------------------------===//
1209// VectorType
1210//===----------------------------------------------------------------------===//
1211
1212// `vector` `<` len `:` type `>`
1213mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) {
1214 int64_t len = 0;
1215 mlir::Type eleTy;
1216 if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
1217 parser.parseType(eleTy) || parser.parseGreater())
1218 return {};
1219 return fir::VectorType::get(len, eleTy);
1220}
1221
1222void fir::VectorType::print(mlir::AsmPrinter &printer) const {
1223 printer << "<" << getLen() << ':' << getEleTy() << '>';
1224}
1225
1226mlir::LogicalResult fir::VectorType::verify(
1227 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
1228 mlir::Type eleTy) {
1229 if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
1230 return emitError() << "cannot build a vector of type " << eleTy << '\n';
1231 return mlir::success();
1232}
1233
1234bool fir::VectorType::isValidElementType(mlir::Type t) {
1235 return isa_real(t) || isa_integer(t);
1236}
1237
1238bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
1239 mlir::TupleType tuple = ty.dyn_cast<mlir::TupleType>();
1240 return tuple && tuple.size() == 2 &&
1241 (tuple.getType(0).isa<fir::BoxProcType>() ||
1242 (acceptRawFunc && tuple.getType(0).isa<mlir::FunctionType>())) &&
1243 fir::isa_integer(tuple.getType(1));
1244}
1245
1246bool fir::hasAbstractResult(mlir::FunctionType ty) {
1247 if (ty.getNumResults() == 0)
1248 return false;
1249 auto resultType = ty.getResult(0);
1250 return resultType.isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>();
1251}
1252
1253/// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
1254/// messages only.
1255mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context,
1256 llvm::Type::TypeID typeID, fir::KindTy kind) {
1257 switch (typeID) {
1258 case llvm::Type::TypeID::HalfTyID:
1259 return mlir::FloatType::getF16(ctx: context);
1260 case llvm::Type::TypeID::BFloatTyID:
1261 return mlir::FloatType::getBF16(ctx: context);
1262 case llvm::Type::TypeID::FloatTyID:
1263 return mlir::FloatType::getF32(ctx: context);
1264 case llvm::Type::TypeID::DoubleTyID:
1265 return mlir::FloatType::getF64(ctx: context);
1266 case llvm::Type::TypeID::X86_FP80TyID:
1267 return mlir::FloatType::getF80(ctx: context);
1268 case llvm::Type::TypeID::FP128TyID:
1269 return mlir::FloatType::getF128(ctx: context);
1270 default:
1271 mlir::emitError(mlir::UnknownLoc::get(context))
1272 << "unsupported type: !fir.real<" << kind << ">";
1273 return {};
1274 }
1275}
1276
1277//===----------------------------------------------------------------------===//
1278// BaseBoxType
1279//===----------------------------------------------------------------------===//
1280
1281mlir::Type BaseBoxType::getEleTy() const {
1282 return llvm::TypeSwitch<fir::BaseBoxType, mlir::Type>(*this)
1283 .Case<fir::BoxType, fir::ClassType>(
1284 [](auto type) { return type.getEleTy(); });
1285}
1286
1287mlir::Type BaseBoxType::unwrapInnerType() const {
1288 return fir::unwrapInnerType(getEleTy());
1289}
1290
1291static mlir::Type
1292changeTypeShape(mlir::Type type,
1293 std::optional<fir::SequenceType::ShapeRef> newShape) {
1294 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
1295 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
1296 if (newShape)
1297 return fir::SequenceType::get(*newShape, seqTy.getEleTy());
1298 return seqTy.getEleTy();
1299 })
1300 .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
1301 fir::ClassType>([&](auto t) -> mlir::Type {
1302 using FIRT = decltype(t);
1303 return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
1304 })
1305 .Default([&](mlir::Type t) -> mlir::Type {
1306 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
1307 llvm::isa<mlir::NoneType>(t)) &&
1308 "unexpected FIR leaf type");
1309 if (newShape)
1310 return fir::SequenceType::get(*newShape, t);
1311 return t;
1312 });
1313}
1314
1315fir::BaseBoxType
1316fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const {
1317 fir::SequenceType seqTy = fir::unwrapUntilSeqType(shapeMold);
1318 std::optional<fir::SequenceType::ShapeRef> newShape;
1319 if (seqTy)
1320 newShape = seqTy.getShape();
1321 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1322}
1323
1324bool fir::BaseBoxType::isAssumedRank() const {
1325 if (auto seqTy =
1326 mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy())))
1327 return seqTy.hasUnknownShape();
1328 return false;
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// FIROpsDialect
1333//===----------------------------------------------------------------------===//
1334
1335void FIROpsDialect::registerTypes() {
1336 addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, ClassType,
1337 fir::ComplexType, FieldType, HeapType, fir::IntegerType, LenType,
1338 LogicalType, LLVMPointerType, PointerType, RealType, RecordType,
1339 ReferenceType, SequenceType, ShapeType, ShapeShiftType, ShiftType,
1340 SliceType, TypeDescType, fir::VectorType>();
1341 fir::ReferenceType::attachInterface<
1342 OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
1343 fir::ReferenceType::attachInterface<
1344 OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext());
1345
1346 fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
1347 *getContext());
1348 fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>(
1349 *getContext());
1350
1351 fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
1352 *getContext());
1353 fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
1354 *getContext());
1355
1356 fir::LLVMPointerType::attachInterface<
1357 OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1358 fir::LLVMPointerType::attachInterface<
1359 OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1360}
1361
1362std::pair<std::uint64_t, unsigned short>
1363fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
1364 const mlir::DataLayout &dl,
1365 const fir::KindMapping &kindMap) {
1366 if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
1367 llvm::TypeSize size = dl.getTypeSize(ty);
1368 unsigned short alignment = dl.getTypeABIAlignment(ty);
1369 return {size, alignment};
1370 }
1371 if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
1372 auto [floatSize, floatAlign] =
1373 getTypeSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
1374 return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
1375 }
1376 if (auto real = mlir::dyn_cast<fir::RealType>(ty))
1377 return getTypeSizeAndAlignment(loc, real.getFloatType(kindMap), dl,
1378 kindMap);
1379
1380 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1381 auto [eleSize, eleAlign] =
1382 getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
1383
1384 std::uint64_t size =
1385 llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
1386 return {size, eleAlign};
1387 }
1388 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
1389 std::uint64_t size = 0;
1390 unsigned short align = 1;
1391 for (auto component : recTy.getTypeList()) {
1392 auto [compSize, compAlign] =
1393 getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
1394 size =
1395 llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
1396 align = std::max(align, compAlign);
1397 }
1398 return {size, align};
1399 }
1400 if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
1401 mlir::Type intTy = mlir::IntegerType::get(
1402 logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
1403 return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1404 }
1405 if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
1406 mlir::Type intTy = mlir::IntegerType::get(
1407 character.getContext(),
1408 kindMap.getCharacterBitsize(character.getFKind()));
1409 return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1410 }
1411 TODO(loc, "computing size of a component");
1412}
1413

// Source: flang/lib/Optimizer/Dialect/FIRType.cpp