1//===-- FIRType.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Dialect/FIRType.h"
14#include "flang/Common/ISO_Fortran_binding_wrapper.h"
15#include "flang/Optimizer/Builder/Todo.h"
16#include "flang/Optimizer/Dialect/FIRDialect.h"
17#include "flang/Optimizer/Dialect/Support/KindMapping.h"
18#include "flang/Tools/PointerModels.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinDialect.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/ErrorHandling.h"
28
29#define GET_TYPEDEF_CLASSES
30#include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
31
32using namespace fir;
33
34namespace {
35
36template <typename TYPE>
37TYPE parseIntSingleton(mlir::AsmParser &parser) {
38 int kind = 0;
39 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseGreater())
40 return {};
41 return TYPE::get(parser.getContext(), kind);
42}
43
44template <typename TYPE>
45TYPE parseKindSingleton(mlir::AsmParser &parser) {
46 return parseIntSingleton<TYPE>(parser);
47}
48
49template <typename TYPE>
50TYPE parseRankSingleton(mlir::AsmParser &parser) {
51 return parseIntSingleton<TYPE>(parser);
52}
53
54template <typename TYPE>
55TYPE parseTypeSingleton(mlir::AsmParser &parser) {
56 mlir::Type ty;
57 if (parser.parseLess() || parser.parseType(result&: ty) || parser.parseGreater())
58 return {};
59 return TYPE::get(ty);
60}
61
62/// Is `ty` a standard or FIR integer type?
63static bool isaIntegerType(mlir::Type ty) {
64 // TODO: why aren't we using isa_integer? investigatation required.
65 return mlir::isa<mlir::IntegerType, fir::IntegerType>(ty);
66}
67
68bool verifyRecordMemberType(mlir::Type ty) {
69 return !mlir::isa<BoxCharType, ShapeType, ShapeShiftType, ShiftType,
70 SliceType, FieldType, LenType, ReferenceType, TypeDescType>(
71 ty);
72}
73
74bool verifySameLists(llvm::ArrayRef<RecordType::TypePair> a1,
75 llvm::ArrayRef<RecordType::TypePair> a2) {
76 // FIXME: do we need to allow for any variance here?
77 return a1 == a2;
78}
79
80static llvm::StringRef getVolatileKeyword() { return "volatile"; }
81
82static mlir::ParseResult parseOptionalCommaAndKeyword(mlir::AsmParser &parser,
83 mlir::StringRef keyword,
84 bool &parsedKeyword) {
85 if (!parser.parseOptionalComma()) {
86 if (parser.parseKeyword(keyword))
87 return mlir::failure();
88 parsedKeyword = true;
89 return mlir::success();
90 }
91 parsedKeyword = false;
92 return mlir::success();
93}
94
95RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
96 llvm::ArrayRef<RecordType::TypePair> lenPList,
97 llvm::ArrayRef<RecordType::TypePair> typeList) {
98 auto loc = parser.getNameLoc();
99 if (!verifySameLists(derivedTy.getLenParamList(), lenPList) ||
100 !verifySameLists(derivedTy.getTypeList(), typeList)) {
101 parser.emitError(loc, message: "cannot redefine record type members");
102 return {};
103 }
104 for (auto &p : lenPList)
105 if (!isaIntegerType(p.second)) {
106 parser.emitError(loc, "LEN parameter must be integral type");
107 return {};
108 }
109 for (auto &p : typeList)
110 if (!verifyRecordMemberType(p.second)) {
111 parser.emitError(loc, "field parameter has invalid type");
112 return {};
113 }
114 llvm::StringSet<> uniq;
115 for (auto &p : lenPList)
116 if (!uniq.insert(p.first).second) {
117 parser.emitError(loc, "LEN parameter cannot have duplicate name");
118 return {};
119 }
120 for (auto &p : typeList)
121 if (!uniq.insert(p.first).second) {
122 parser.emitError(loc, "field cannot have duplicate name");
123 return {};
124 }
125 return derivedTy;
126}
127
128} // namespace
129
130// Implementation of the thin interface from dialect to type parser
131
132mlir::Type fir::parseFirType(FIROpsDialect *dialect,
133 mlir::DialectAsmParser &parser) {
134 mlir::StringRef typeTag;
135 mlir::Type genType;
136 auto parseResult = generatedTypeParser(parser, &typeTag, genType);
137 if (parseResult.has_value())
138 return genType;
139 parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;
140 return {};
141}
142
143namespace fir {
144namespace detail {
145
146// Type storage classes
147
148/// Derived type storage
149struct RecordTypeStorage : public mlir::TypeStorage {
150 using KeyTy = llvm::StringRef;
151
152 static unsigned hashKey(const KeyTy &key) {
153 return llvm::hash_combine(args: key.str());
154 }
155
156 bool operator==(const KeyTy &key) const { return key == getName(); }
157
158 static RecordTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
159 const KeyTy &key) {
160 auto *storage = allocator.allocate<RecordTypeStorage>();
161 return new (storage) RecordTypeStorage{key};
162 }
163
164 llvm::StringRef getName() const { return name; }
165
166 void setLenParamList(llvm::ArrayRef<RecordType::TypePair> list) {
167 lens = list;
168 }
169 llvm::ArrayRef<RecordType::TypePair> getLenParamList() const { return lens; }
170
171 void setTypeList(llvm::ArrayRef<RecordType::TypePair> list) { types = list; }
172 llvm::ArrayRef<RecordType::TypePair> getTypeList() const { return types; }
173
174 bool isFinalized() const { return finalized; }
175 void finalize(llvm::ArrayRef<RecordType::TypePair> lenParamList,
176 llvm::ArrayRef<RecordType::TypePair> typeList) {
177 if (finalized)
178 return;
179 finalized = true;
180 setLenParamList(lenParamList);
181 setTypeList(typeList);
182 }
183
184 bool isPacked() const { return packed; }
185 void pack(bool p) { packed = p; }
186
187protected:
188 std::string name;
189 bool finalized;
190 bool packed;
191 std::vector<RecordType::TypePair> lens;
192 std::vector<RecordType::TypePair> types;
193
194private:
195 RecordTypeStorage() = delete;
196 explicit RecordTypeStorage(llvm::StringRef name)
197 : name{name}, finalized{false}, packed{false} {}
198};
199
200} // namespace detail
201
/// Half-open interval test: true iff `lb <= v` and `v < ub`.
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
  if (!(v >= lb))
    return false;
  return v < ub;
}
206
207bool isa_fir_type(mlir::Type t) {
208 return llvm::isa<FIROpsDialect>(t.getDialect());
209}
210
211bool isa_std_type(mlir::Type t) {
212 return llvm::isa<mlir::BuiltinDialect>(Val: t.getDialect());
213}
214
215bool isa_fir_or_std_type(mlir::Type t) {
216 if (auto funcType = mlir::dyn_cast<mlir::FunctionType>(Val&: t))
217 return llvm::all_of(Range: funcType.getInputs(), P: isa_fir_or_std_type) &&
218 llvm::all_of(Range: funcType.getResults(), P: isa_fir_or_std_type);
219 return isa_fir_type(t) || isa_std_type(t);
220}
221
222mlir::Type getDerivedType(mlir::Type ty) {
223 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
224 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto p) {
225 if (auto seq = mlir::dyn_cast<fir::SequenceType>(p.getEleTy()))
226 return seq.getEleTy();
227 return p.getEleTy();
228 })
229 .Case<fir::BaseBoxType>(
230 [](auto p) { return getDerivedType(p.getEleTy()); })
231 .Default([](mlir::Type t) { return t; });
232}
233
234mlir::Type updateTypeWithVolatility(mlir::Type type, bool isVolatile) {
235 // If we already have the volatility we asked for, return the type unchanged.
236 if (fir::isa_volatile_type(type) == isVolatile)
237 return type;
238 return mlir::TypeSwitch<mlir::Type, mlir::Type>(type)
239 .Case<fir::BoxType, fir::ClassType, fir::ReferenceType>(
240 [&](auto ty) -> mlir::Type {
241 using TYPE = decltype(ty);
242 return TYPE::get(ty.getEleTy(), isVolatile);
243 })
244 .Default([&](mlir::Type t) -> mlir::Type { return t; });
245}
246
247mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
248 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
249 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
250 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
251 .Default([](mlir::Type) { return mlir::Type{}; });
252}
253
254mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
255 return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
256 .Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
257 fir::LLVMPointerType>([](auto p) { return p.getEleTy(); })
258 .Case<fir::BaseBoxType, fir::BoxCharType>(
259 [](auto p) { return unwrapRefType(p.getEleTy()); })
260 .Default([](mlir::Type) { return mlir::Type{}; });
261}
262
263static bool hasDynamicSize(fir::RecordType recTy) {
264 if (recTy.getLenParamList().empty())
265 return false;
266 for (auto field : recTy.getTypeList()) {
267 if (auto arr = mlir::dyn_cast<fir::SequenceType>(field.second)) {
268 if (sequenceWithNonConstantShape(arr))
269 return true;
270 } else if (characterWithDynamicLen(field.second)) {
271 return true;
272 } else if (auto rec = mlir::dyn_cast<fir::RecordType>(field.second)) {
273 if (hasDynamicSize(rec))
274 return true;
275 }
276 }
277 return false;
278}
279
280bool hasDynamicSize(mlir::Type t) {
281 if (auto arr = mlir::dyn_cast<fir::SequenceType>(t)) {
282 if (sequenceWithNonConstantShape(arr))
283 return true;
284 t = arr.getEleTy();
285 }
286 if (characterWithDynamicLen(t))
287 return true;
288 if (auto rec = mlir::dyn_cast<fir::RecordType>(t))
289 return hasDynamicSize(rec);
290 return false;
291}
292
293mlir::Type extractSequenceType(mlir::Type ty) {
294 if (mlir::isa<fir::SequenceType>(ty))
295 return ty;
296 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
297 return extractSequenceType(boxTy.getEleTy());
298 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
299 return extractSequenceType(heapTy.getEleTy());
300 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
301 return extractSequenceType(ptrTy.getEleTy());
302 return mlir::Type{};
303}
304
305bool isPointerType(mlir::Type ty) {
306 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
307 ty = refTy;
308 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
309 return mlir::isa<fir::PointerType>(boxTy.getEleTy());
310 return false;
311}
312
313bool isAllocatableType(mlir::Type ty) {
314 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
315 ty = refTy;
316 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
317 return mlir::isa<fir::HeapType>(boxTy.getEleTy());
318 return false;
319}
320
321bool isBoxNone(mlir::Type ty) {
322 if (auto box = mlir::dyn_cast<fir::BoxType>(ty))
323 return mlir::isa<mlir::NoneType>(box.getEleTy());
324 return false;
325}
326
327bool isBoxedRecordType(mlir::Type ty) {
328 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
329 ty = refTy;
330 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
331 if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
332 return true;
333 mlir::Type innerType = boxTy.unwrapInnerType();
334 return innerType && mlir::isa<fir::RecordType>(innerType);
335 }
336 return false;
337}
338
339bool isScalarBoxedRecordType(mlir::Type ty) {
340 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
341 ty = refTy;
342 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
343 if (mlir::isa<fir::RecordType>(boxTy.getEleTy()))
344 return true;
345 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
346 return mlir::isa<fir::RecordType>(heapTy.getEleTy());
347 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
348 return mlir::isa<fir::RecordType>(ptrTy.getEleTy());
349 }
350 return false;
351}
352
353bool isAssumedType(mlir::Type ty) {
354 // Rule out CLASS(*) which are `fir.class<[fir.array] none>`.
355 if (mlir::isa<fir::ClassType>(ty))
356 return false;
357 mlir::Type valueType = fir::unwrapPassByRefType(fir::unwrapRefType(ty));
358 // Refuse raw `none` or `fir.array<none>` since assumed type
359 // should be in memory variables.
360 if (valueType == ty)
361 return false;
362 mlir::Type inner = fir::unwrapSequenceType(valueType);
363 return mlir::isa<mlir::NoneType>(Val: inner);
364}
365
366bool isAssumedShape(mlir::Type ty) {
367 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
368 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
369 return seqTy.hasDynamicExtents();
370 return false;
371}
372
373bool isAllocatableOrPointerArray(mlir::Type ty) {
374 if (auto refTy = fir::dyn_cast_ptrEleTy(t: ty))
375 ty = refTy;
376 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
377 if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
378 return mlir::isa<fir::SequenceType>(heapTy.getEleTy());
379 if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
380 return mlir::isa<fir::SequenceType>(ptrTy.getEleTy());
381 }
382 return false;
383}
384
385bool isTypeWithDescriptor(mlir::Type ty) {
386 if (mlir::isa<fir::BaseBoxType>(unwrapRefType(ty)))
387 return true;
388 return false;
389}
390
391bool isPolymorphicType(mlir::Type ty) {
392 // CLASS(T) or CLASS(*)
393 if (mlir::isa<fir::ClassType>(fir::unwrapRefType(ty)))
394 return true;
395 // assumed type are polymorphic.
396 return isAssumedType(ty);
397}
398
399bool isUnlimitedPolymorphicType(mlir::Type ty) {
400 // CLASS(*)
401 if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) {
402 if (mlir::isa<mlir::NoneType>(clTy.getEleTy()))
403 return true;
404 mlir::Type innerType = clTy.unwrapInnerType();
405 return innerType && mlir::isa<mlir::NoneType>(Val: innerType);
406 }
407 // TYPE(*)
408 return isAssumedType(ty);
409}
410
411mlir::Type unwrapInnerType(mlir::Type ty) {
412 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ty)
413 .Case<fir::PointerType, fir::HeapType, fir::SequenceType>([](auto t) {
414 mlir::Type eleTy = t.getEleTy();
415 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
416 return seqTy.getEleTy();
417 return eleTy;
418 })
419 .Case<fir::RecordType>([](auto t) { return t; })
420 .Default([](mlir::Type) { return mlir::Type{}; });
421}
422
423bool isRecordWithAllocatableMember(mlir::Type ty) {
424 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty))
425 for (auto [field, memTy] : recTy.getTypeList()) {
426 if (fir::isAllocatableType(memTy))
427 return true;
428 // A record type cannot recursively include itself as a direct member.
429 // There must be an intervening `ptr` type, so recursion is safe here.
430 if (mlir::isa<fir::RecordType>(memTy) &&
431 isRecordWithAllocatableMember(memTy))
432 return true;
433 }
434 return false;
435}
436
437bool isRecordWithDescriptorMember(mlir::Type ty) {
438 ty = unwrapSequenceType(ty);
439 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty))
440 for (auto [field, memTy] : recTy.getTypeList()) {
441 memTy = unwrapSequenceType(memTy);
442 if (mlir::isa<fir::BaseBoxType>(memTy))
443 return true;
444 if (mlir::isa<fir::RecordType>(memTy) &&
445 isRecordWithDescriptorMember(memTy))
446 return true;
447 }
448 return false;
449}
450
451mlir::Type unwrapAllRefAndSeqType(mlir::Type ty) {
452 while (true) {
453 mlir::Type nt = unwrapSequenceType(unwrapRefType(ty));
454 if (auto vecTy = mlir::dyn_cast<fir::VectorType>(nt))
455 nt = vecTy.getEleTy();
456 if (nt == ty)
457 return ty;
458 ty = nt;
459 }
460}
461
462mlir::Type getFortranElementType(mlir::Type ty) {
463 return fir::unwrapSequenceType(
464 fir::unwrapPassByRefType(fir::unwrapRefType(ty)));
465}
466
467mlir::Type unwrapSeqOrBoxedSeqType(mlir::Type ty) {
468 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
469 return seqTy.getEleTy();
470 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
471 auto eleTy = unwrapRefType(boxTy.getEleTy());
472 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
473 return seqTy.getEleTy();
474 }
475 return ty;
476}
477
478unsigned getBoxRank(mlir::Type boxTy) {
479 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(t: boxTy);
480 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
481 return seqTy.getDimension();
482 return 0;
483}
484
485/// Return the ISO_C_BINDING intrinsic module value of type \p ty.
486int getTypeCode(mlir::Type ty, const fir::KindMapping &kindMap) {
487 if (mlir::IntegerType intTy = mlir::dyn_cast<mlir::IntegerType>(Val&: ty)) {
488 if (intTy.isUnsigned()) {
489 switch (intTy.getWidth()) {
490 case 8:
491 return CFI_type_uint8_t;
492 case 16:
493 return CFI_type_uint16_t;
494 case 32:
495 return CFI_type_uint32_t;
496 case 64:
497 return CFI_type_uint64_t;
498 case 128:
499 return CFI_type_uint128_t;
500 }
501 llvm_unreachable("unsupported integer type");
502 } else {
503 switch (intTy.getWidth()) {
504 case 8:
505 return CFI_type_int8_t;
506 case 16:
507 return CFI_type_int16_t;
508 case 32:
509 return CFI_type_int32_t;
510 case 64:
511 return CFI_type_int64_t;
512 case 128:
513 return CFI_type_int128_t;
514 }
515 llvm_unreachable("unsupported integer type");
516 }
517 }
518 if (fir::LogicalType logicalTy = mlir::dyn_cast<fir::LogicalType>(ty)) {
519 switch (kindMap.getLogicalBitsize(logicalTy.getFKind())) {
520 case 8:
521 return CFI_type_Bool;
522 case 16:
523 return CFI_type_int_least16_t;
524 case 32:
525 return CFI_type_int_least32_t;
526 case 64:
527 return CFI_type_int_least64_t;
528 }
529 llvm_unreachable("unsupported logical type");
530 }
531 if (mlir::FloatType floatTy = mlir::dyn_cast<mlir::FloatType>(Val&: ty)) {
532 switch (floatTy.getWidth()) {
533 case 16:
534 return floatTy.isBF16() ? CFI_type_bfloat : CFI_type_half_float;
535 case 32:
536 return CFI_type_float;
537 case 64:
538 return CFI_type_double;
539 case 80:
540 return CFI_type_extended_double;
541 case 128:
542 return CFI_type_float128;
543 }
544 llvm_unreachable("unsupported real type");
545 }
546 if (mlir::ComplexType complexTy = mlir::dyn_cast<mlir::ComplexType>(Val&: ty)) {
547 mlir::FloatType floatTy =
548 mlir::cast<mlir::FloatType>(Val: complexTy.getElementType());
549 if (floatTy.isBF16())
550 return CFI_type_bfloat_Complex;
551 switch (floatTy.getWidth()) {
552 case 16:
553 return CFI_type_half_float_Complex;
554 case 32:
555 return CFI_type_float_Complex;
556 case 64:
557 return CFI_type_double_Complex;
558 case 80:
559 return CFI_type_extended_double_Complex;
560 case 128:
561 return CFI_type_float128_Complex;
562 }
563 llvm_unreachable("unsupported complex size");
564 }
565 if (fir::CharacterType charTy = mlir::dyn_cast<fir::CharacterType>(ty)) {
566 switch (kindMap.getCharacterBitsize(charTy.getFKind())) {
567 case 8:
568 return CFI_type_char;
569 case 16:
570 return CFI_type_char16_t;
571 case 32:
572 return CFI_type_char32_t;
573 }
574 llvm_unreachable("unsupported character type");
575 }
576 if (fir::isa_ref_type(ty))
577 return CFI_type_cptr;
578 if (mlir::isa<fir::RecordType>(ty))
579 return CFI_type_struct;
580 llvm_unreachable("unsupported type");
581}
582
583std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap,
584 llvm::StringRef prefix) {
585 std::string buf = prefix.str();
586 llvm::raw_string_ostream name{buf};
587 if (!prefix.empty())
588 name << "_";
589
590 std::function<void(mlir::Type)> appendTypeName = [&](mlir::Type ty) {
591 while (ty) {
592 if (fir::isa_trivial(ty)) {
593 if (mlir::isa<mlir::IndexType>(Val: ty)) {
594 name << "idx";
595 } else if (ty.isIntOrIndex()) {
596 name << 'i' << ty.getIntOrFloatBitWidth();
597 } else if (mlir::isa<mlir::FloatType>(Val: ty)) {
598 name << 'f' << ty.getIntOrFloatBitWidth();
599 } else if (auto cplxTy =
600 mlir::dyn_cast_or_null<mlir::ComplexType>(Val&: ty)) {
601 name << 'z';
602 auto floatTy = mlir::cast<mlir::FloatType>(Val: cplxTy.getElementType());
603 name << floatTy.getWidth();
604 } else if (auto logTy = mlir::dyn_cast_or_null<fir::LogicalType>(ty)) {
605 name << 'l' << kindMap.getLogicalBitsize(logTy.getFKind());
606 } else {
607 llvm::report_fatal_error(reason: "unsupported type");
608 }
609 break;
610 } else if (mlir::isa<mlir::NoneType>(Val: ty)) {
611 name << "none";
612 break;
613 } else if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(ty)) {
614 name << 'c' << kindMap.getCharacterBitsize(charTy.getFKind());
615 if (charTy.getLen() == fir::CharacterType::unknownLen())
616 name << "xU";
617 else if (charTy.getLen() != fir::CharacterType::singleton())
618 name << "x" << charTy.getLen();
619 break;
620 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) {
621 for (auto extent : seqTy.getShape()) {
622 if (extent == fir::SequenceType::getUnknownExtent())
623 name << "Ux";
624 else
625 name << extent << 'x';
626 }
627 ty = seqTy.getEleTy();
628 } else if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
629 name << "ref_";
630 ty = refTy.getEleTy();
631 } else if (auto ptrTy = mlir::dyn_cast_or_null<fir::PointerType>(ty)) {
632 name << "ptr_";
633 ty = ptrTy.getEleTy();
634 } else if (auto ptrTy =
635 mlir::dyn_cast_or_null<fir::LLVMPointerType>(ty)) {
636 name << "llvmptr_";
637 ty = ptrTy.getEleTy();
638 } else if (auto heapTy = mlir::dyn_cast_or_null<fir::HeapType>(ty)) {
639 name << "heap_";
640 ty = heapTy.getEleTy();
641 } else if (auto classTy = mlir::dyn_cast_or_null<fir::ClassType>(ty)) {
642 name << "class_";
643 ty = classTy.getEleTy();
644 } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BoxType>(ty)) {
645 name << "box_";
646 ty = boxTy.getEleTy();
647 } else if (auto boxcharTy =
648 mlir::dyn_cast_or_null<fir::BoxCharType>(ty)) {
649 name << "boxchar_";
650 ty = boxcharTy.getEleTy();
651 } else if (auto boxprocTy =
652 mlir::dyn_cast_or_null<fir::BoxProcType>(ty)) {
653 name << "boxproc_";
654 auto procTy = mlir::dyn_cast<mlir::FunctionType>(boxprocTy.getEleTy());
655 assert(procTy.getNumResults() <= 1 &&
656 "function type with more than one result");
657 for (const auto &result : procTy.getResults())
658 appendTypeName(result);
659 name << "_args";
660 for (const auto &arg : procTy.getInputs()) {
661 name << '_';
662 appendTypeName(arg);
663 }
664 break;
665 } else if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(ty)) {
666 name << "rec_" << recTy.getName();
667 break;
668 } else {
669 llvm::report_fatal_error(reason: "unsupported type");
670 }
671 }
672 };
673
674 appendTypeName(ty);
675 return buf;
676}
677
678mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
679 bool turnBoxIntoClass) {
680 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
681 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
682 return fir::SequenceType::get(seqTy.getShape(), newElementType);
683 })
684 .Case<fir::ReferenceType, fir::ClassType>([&](auto t) -> mlir::Type {
685 using FIRT = decltype(t);
686 auto newEleTy =
687 changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass);
688 return FIRT::get(newEleTy, t.isVolatile());
689 })
690 .Case<fir::PointerType, fir::HeapType>([&](auto t) -> mlir::Type {
691 using FIRT = decltype(t);
692 return FIRT::get(
693 changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
694 })
695 .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
696 mlir::Type newInnerType =
697 changeElementType(t.getEleTy(), newElementType, false);
698 if (turnBoxIntoClass)
699 return fir::ClassType::get(newInnerType, t.isVolatile());
700 return fir::BoxType::get(newInnerType, t.isVolatile());
701 })
702 .Default([&](mlir::Type t) -> mlir::Type {
703 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
704 llvm::isa<mlir::NoneType>(t)) &&
705 "unexpected FIR leaf type");
706 return newElementType;
707 });
708}
709
710} // namespace fir
711
712namespace {
713
714static llvm::SmallPtrSet<detail::RecordTypeStorage const *, 4>
715 recordTypeVisited;
716
717} // namespace
718
719void fir::verifyIntegralType(mlir::Type type) {
720 if (isaIntegerType(ty: type) || mlir::isa<mlir::IndexType>(Val: type))
721 return;
722 llvm::report_fatal_error(reason: "expected integral type");
723}
724
725void fir::printFirType(FIROpsDialect *, mlir::Type ty,
726 mlir::DialectAsmPrinter &p) {
727 if (mlir::failed(Result: generatedTypePrinter(ty, p)))
728 llvm::report_fatal_error(reason: "unknown type to print");
729}
730
731bool fir::isa_unknown_size_box(mlir::Type t) {
732 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(t)) {
733 auto valueType = fir::unwrapPassByRefType(boxTy);
734 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(valueType))
735 if (seqTy.hasUnknownShape())
736 return true;
737 }
738 return false;
739}
740
741bool fir::isa_volatile_type(mlir::Type t) {
742 return llvm::TypeSwitch<mlir::Type, bool>(t)
743 .Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(
744 [](auto t) { return t.isVolatile(); })
745 .Default([](mlir::Type) { return false; });
746}
747
748//===----------------------------------------------------------------------===//
749// BoxProcType
750//===----------------------------------------------------------------------===//
751
752// `boxproc` `<` return-type `>`
753mlir::Type BoxProcType::parse(mlir::AsmParser &parser) {
754 mlir::Type ty;
755 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater())
756 return {};
757 return get(parser.getContext(), ty);
758}
759
760void fir::BoxProcType::print(mlir::AsmPrinter &printer) const {
761 printer << "<" << getEleTy() << '>';
762}
763
764llvm::LogicalResult
765BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
766 mlir::Type eleTy) {
767 if (mlir::isa<mlir::FunctionType>(eleTy))
768 return mlir::success();
769 if (auto refTy = mlir::dyn_cast<ReferenceType>(eleTy))
770 if (mlir::isa<mlir::FunctionType>(refTy))
771 return mlir::success();
772 return emitError() << "invalid type for boxproc" << eleTy << '\n';
773}
774
775static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
776 return mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
777 SliceType, FieldType, LenType, HeapType, PointerType,
778 ReferenceType, TypeDescType>(eleTy);
779}
780
781//===----------------------------------------------------------------------===//
782// BoxType
783//===----------------------------------------------------------------------===//
784
785// `box` `<` type (`, volatile` $volatile^)? `>`
786mlir::Type fir::BoxType::parse(mlir::AsmParser &parser) {
787 mlir::Type eleTy;
788 auto location = parser.getCurrentLocation();
789 auto *context = parser.getContext();
790 bool isVolatile = false;
791 if (parser.parseLess() || parser.parseType(eleTy))
792 return {};
793 if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
794 return {};
795 if (parser.parseGreater())
796 return {};
797 return parser.getChecked<fir::BoxType>(location, context, eleTy, isVolatile);
798}
799
800void fir::BoxType::print(mlir::AsmPrinter &printer) const {
801 printer << "<" << getEleTy();
802 if (isVolatile())
803 printer << ", " << getVolatileKeyword();
804 printer << '>';
805}
806
807llvm::LogicalResult
808fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
809 mlir::Type eleTy, bool isVolatile) {
810 if (mlir::isa<fir::BaseBoxType>(eleTy))
811 return emitError() << "invalid element type\n";
812 // TODO
813 return mlir::success();
814}
815
816//===----------------------------------------------------------------------===//
817// BoxCharType
818//===----------------------------------------------------------------------===//
819
820mlir::Type fir::BoxCharType::parse(mlir::AsmParser &parser) {
821 return parseKindSingleton<fir::BoxCharType>(parser);
822}
823
824void fir::BoxCharType::print(mlir::AsmPrinter &printer) const {
825 printer << "<" << getKind() << ">";
826}
827
828CharacterType
829fir::BoxCharType::getElementType(mlir::MLIRContext *context) const {
830 return CharacterType::getUnknownLen(context, getKind());
831}
832
833CharacterType fir::BoxCharType::getEleTy() const {
834 return getElementType(getContext());
835}
836
837//===----------------------------------------------------------------------===//
838// CharacterType
839//===----------------------------------------------------------------------===//
840
841// `char` `<` kind [`,` `len`] `>`
842mlir::Type fir::CharacterType::parse(mlir::AsmParser &parser) {
843 int kind = 0;
844 if (parser.parseLess() || parser.parseInteger(kind))
845 return {};
846 CharacterType::LenType len = 1;
847 if (mlir::succeeded(parser.parseOptionalComma())) {
848 if (mlir::succeeded(parser.parseOptionalQuestion())) {
849 len = fir::CharacterType::unknownLen();
850 } else if (!mlir::succeeded(parser.parseInteger(len))) {
851 return {};
852 }
853 }
854 if (parser.parseGreater())
855 return {};
856 return get(parser.getContext(), kind, len);
857}
858
859void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
860 printer << "<" << getFKind();
861 auto len = getLen();
862 if (len != fir::CharacterType::singleton()) {
863 printer << ',';
864 if (len == fir::CharacterType::unknownLen())
865 printer << '?';
866 else
867 printer << len;
868 }
869 printer << '>';
870}
871
872//===----------------------------------------------------------------------===//
873// ClassType
874//===----------------------------------------------------------------------===//
875
876// `class` `<` type (`, volatile` $volatile^)? `>`
877mlir::Type fir::ClassType::parse(mlir::AsmParser &parser) {
878 mlir::Type eleTy;
879 auto location = parser.getCurrentLocation();
880 auto *context = parser.getContext();
881 bool isVolatile = false;
882 if (parser.parseLess() || parser.parseType(eleTy))
883 return {};
884 if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
885 return {};
886 if (parser.parseGreater())
887 return {};
888 return parser.getChecked<fir::ClassType>(location, context, eleTy,
889 isVolatile);
890}
891
892void fir::ClassType::print(mlir::AsmPrinter &printer) const {
893 printer << "<" << getEleTy();
894 if (isVolatile())
895 printer << ", " << getVolatileKeyword();
896 printer << '>';
897}
898
899llvm::LogicalResult
900fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
901 mlir::Type eleTy, bool isVolatile) {
902 if (mlir::isa<fir::RecordType, fir::SequenceType, fir::HeapType,
903 fir::PointerType, mlir::NoneType, mlir::IntegerType,
904 mlir::FloatType, fir::CharacterType, fir::LogicalType,
905 mlir::ComplexType>(eleTy))
906 return mlir::success();
907 return emitError() << "invalid element type\n";
908}
909
910//===----------------------------------------------------------------------===//
911// HeapType
912//===----------------------------------------------------------------------===//
913
914// `heap` `<` type `>`
915mlir::Type fir::HeapType::parse(mlir::AsmParser &parser) {
916 return parseTypeSingleton<HeapType>(parser);
917}
918
919void fir::HeapType::print(mlir::AsmPrinter &printer) const {
920 printer << "<" << getEleTy() << '>';
921}
922
923llvm::LogicalResult
924fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
925 mlir::Type eleTy) {
926 if (cannotBePointerOrHeapElementType(eleTy))
927 return emitError() << "cannot build a heap pointer to type: " << eleTy
928 << '\n';
929 return mlir::success();
930}
931
932//===----------------------------------------------------------------------===//
933// IntegerType
934//===----------------------------------------------------------------------===//
935
936// `int` `<` kind `>`
937mlir::Type fir::IntegerType::parse(mlir::AsmParser &parser) {
938 return parseKindSingleton<fir::IntegerType>(parser);
939}
940
941void fir::IntegerType::print(mlir::AsmPrinter &printer) const {
942 printer << "<" << getFKind() << '>';
943}
944
945//===----------------------------------------------------------------------===//
946// UnsignedType
947//===----------------------------------------------------------------------===//
948
949// `unsigned` `<` kind `>`
950mlir::Type fir::UnsignedType::parse(mlir::AsmParser &parser) {
951 return parseKindSingleton<fir::UnsignedType>(parser);
952}
953
954void fir::UnsignedType::print(mlir::AsmPrinter &printer) const {
955 printer << "<" << getFKind() << '>';
956}
957
958//===----------------------------------------------------------------------===//
959// LogicalType
960//===----------------------------------------------------------------------===//
961
962// `logical` `<` kind `>`
963mlir::Type fir::LogicalType::parse(mlir::AsmParser &parser) {
964 return parseKindSingleton<fir::LogicalType>(parser);
965}
966
967void fir::LogicalType::print(mlir::AsmPrinter &printer) const {
968 printer << "<" << getFKind() << '>';
969}
970
971//===----------------------------------------------------------------------===//
972// PointerType
973//===----------------------------------------------------------------------===//
974
975// `ptr` `<` type `>`
976mlir::Type fir::PointerType::parse(mlir::AsmParser &parser) {
977 return parseTypeSingleton<fir::PointerType>(parser);
978}
979
980void fir::PointerType::print(mlir::AsmPrinter &printer) const {
981 printer << "<" << getEleTy() << '>';
982}
983
984llvm::LogicalResult fir::PointerType::verify(
985 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
986 mlir::Type eleTy) {
987 if (cannotBePointerOrHeapElementType(eleTy))
988 return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
989 return mlir::success();
990}
991
992//===----------------------------------------------------------------------===//
993// RecordType
994//===----------------------------------------------------------------------===//
995
996// Fortran derived type
997// unpacked:
998// `type` `<` name
999// (`(` id `:` type (`,` id `:` type)* `)`)?
1000// (`{` id `:` type (`,` id `:` type)* `}`)? '>'
1001// packed:
1002// `type` `<` name
1003// (`(` id `:` type (`,` id `:` type)* `)`)?
1004// (`<{` id `:` type (`,` id `:` type)* `}>`)? '>'
/// Parse the body of a record type (see the grammar comment above).
///
/// The record is first obtained by name with RecordType::get. If neither a
/// LEN parameter list nor a component list follows, that record is returned
/// as-is, without calling finalize. Otherwise it is finalized with the parsed
/// lists, and the result of verifyDerived is returned.
mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
  llvm::StringRef name;
  if (parser.parseLess() || parser.parseKeyword(&name))
    return {};
  RecordType result = RecordType::get(parser.getContext(), name);

  // Optional LEN parameter list: `(` id `:` type (`,` id `:` type)* `)`.
  RecordType::TypeVector lenParamList;
  if (!parser.parseOptionalLParen()) {
    while (true) {
      llvm::StringRef lenparam;
      mlir::Type intTy;
      if (parser.parseKeyword(&lenparam) || parser.parseColon() ||
          parser.parseType(intTy)) {
        parser.emitError(parser.getNameLoc(), "expected LEN parameter list");
        return {};
      }
      lenParamList.emplace_back(lenparam, intTy);
      if (parser.parseOptionalComma())
        break;
    }
    if (parser.parseRParen())
      return {};
  }

  // A `<` in front of the component list marks a packed record (`<{...}>`).
  // This sets the packed flag on the record obtained above.
  RecordType::TypeVector typeList;
  if (!parser.parseOptionalLess()) {
    result.pack(true);
  }

  // Optional component list: `{` id `:` type (`,` id `:` type)* `}`.
  if (!parser.parseOptionalLBrace()) {
    while (true) {
      llvm::StringRef field;
      mlir::Type fldTy;
      if (parser.parseKeyword(&field) || parser.parseColon() ||
          parser.parseType(fldTy)) {
        parser.emitError(parser.getNameLoc(), "expected field type list");
        return {};
      }
      typeList.emplace_back(field, fldTy);
      if (parser.parseOptionalComma())
        break;
    }
    // NOTE(review): this accepts either `>` or `}` after the last component.
    // For a packed record, only one of the two closing `>` that
    // RecordType::print emits (`}>` then `>`) is consumed here and below.
    // This presumably relies on the enclosing MLIR parser resuming after the
    // bracket-balanced type body. Confirm before tightening.
    if (parser.parseOptionalGreater()) {
      if (parser.parseRBrace())
        return {};
    }
  }

  if (parser.parseGreater())
    return {};

  // A reference to the record by name only: return it without finalizing.
  if (lenParamList.empty() && typeList.empty())
    return result;

  result.finalize(lenParamList, typeList);
  return verifyDerived(parser, result, lenParamList, typeList);
}
1062
1063void fir::RecordType::print(mlir::AsmPrinter &printer) const {
1064 printer << "<" << getName();
1065 if (!recordTypeVisited.count(uniqueKey())) {
1066 recordTypeVisited.insert(uniqueKey());
1067 if (getLenParamList().size()) {
1068 char ch = '(';
1069 for (auto p : getLenParamList()) {
1070 printer << ch << p.first << ':';
1071 p.second.print(printer.getStream());
1072 ch = ',';
1073 }
1074 printer << ')';
1075 }
1076 if (getTypeList().size()) {
1077 if (isPacked()) {
1078 printer << '<';
1079 }
1080 char ch = '{';
1081 for (auto p : getTypeList()) {
1082 printer << ch << p.first << ':';
1083 p.second.print(printer.getStream());
1084 ch = ',';
1085 }
1086 printer << '}';
1087 if (isPacked()) {
1088 printer << '>';
1089 }
1090 }
1091 recordTypeVisited.erase(uniqueKey());
1092 }
1093 printer << '>';
1094}
1095
1096void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
1097 llvm::ArrayRef<TypePair> typeList) {
1098 getImpl()->finalize(lenPList, typeList);
1099}
1100
1101llvm::StringRef fir::RecordType::getName() const {
1102 return getImpl()->getName();
1103}
1104
1105RecordType::TypeList fir::RecordType::getTypeList() const {
1106 return getImpl()->getTypeList();
1107}
1108
1109RecordType::TypeList fir::RecordType::getLenParamList() const {
1110 return getImpl()->getLenParamList();
1111}
1112
1113bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
1114
1115void fir::RecordType::pack(bool p) { getImpl()->pack(p); }
1116
1117bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }
1118
1119detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
1120 return getImpl();
1121}
1122
1123llvm::LogicalResult fir::RecordType::verify(
1124 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1125 llvm::StringRef name) {
1126 if (name.size() == 0)
1127 return emitError() << "record types must have a name";
1128 return mlir::success();
1129}
1130
1131mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
1132 for (auto f : getTypeList())
1133 if (ident == f.first)
1134 return f.second;
1135 return {};
1136}
1137
1138unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
1139 for (auto f : llvm::enumerate(getTypeList()))
1140 if (ident == f.value().first)
1141 return f.index();
1142 return std::numeric_limits<unsigned>::max();
1143}
1144
1145//===----------------------------------------------------------------------===//
1146// ReferenceType
1147//===----------------------------------------------------------------------===//
1148
1149// `ref` `<` type (`, volatile` $volatile^)? `>`
1150mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
1151 auto location = parser.getCurrentLocation();
1152 auto *context = parser.getContext();
1153 mlir::Type eleTy;
1154 bool isVolatile = false;
1155 if (parser.parseLess() || parser.parseType(eleTy))
1156 return {};
1157 if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
1158 return {};
1159 if (parser.parseGreater())
1160 return {};
1161 return parser.getChecked<fir::ReferenceType>(location, context, eleTy,
1162 isVolatile);
1163}
1164
1165void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
1166 printer << "<" << getEleTy();
1167 if (isVolatile())
1168 printer << ", " << getVolatileKeyword();
1169 printer << '>';
1170}
1171
1172llvm::LogicalResult fir::ReferenceType::verify(
1173 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
1174 bool isVolatile) {
1175 if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
1176 ReferenceType, TypeDescType>(eleTy))
1177 return emitError() << "cannot build a reference to type: " << eleTy << '\n';
1178 return mlir::success();
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// SequenceType
1183//===----------------------------------------------------------------------===//
1184
1185// `array` `<` `*` | bounds (`x` bounds)* `:` type (',' affine-map)? `>`
1186// bounds ::= `?` | int-lit
1187mlir::Type fir::SequenceType::parse(mlir::AsmParser &parser) {
1188 if (parser.parseLess())
1189 return {};
1190 SequenceType::Shape shape;
1191 if (parser.parseOptionalStar()) {
1192 if (parser.parseDimensionList(shape, /*allowDynamic=*/true))
1193 return {};
1194 } else if (parser.parseColon()) {
1195 return {};
1196 }
1197 mlir::Type eleTy;
1198 if (parser.parseType(eleTy))
1199 return {};
1200 mlir::AffineMapAttr map;
1201 if (!parser.parseOptionalComma()) {
1202 if (parser.parseAttribute(map)) {
1203 parser.emitError(parser.getNameLoc(), "expecting affine map");
1204 return {};
1205 }
1206 }
1207 if (parser.parseGreater())
1208 return {};
1209 return SequenceType::get(parser.getContext(), shape, eleTy, map);
1210}
1211
1212void fir::SequenceType::print(mlir::AsmPrinter &printer) const {
1213 auto shape = getShape();
1214 if (shape.size()) {
1215 printer << '<';
1216 for (const auto &b : shape) {
1217 if (b >= 0)
1218 printer << b << 'x';
1219 else
1220 printer << "?x";
1221 }
1222 } else {
1223 printer << "<*:";
1224 }
1225 printer << getEleTy();
1226 if (auto map = getLayoutMap()) {
1227 printer << ", ";
1228 map.print(printer.getStream());
1229 }
1230 printer << '>';
1231}
1232
1233unsigned fir::SequenceType::getConstantRows() const {
1234 if (hasDynamicSize(getEleTy()))
1235 return 0;
1236 auto shape = getShape();
1237 unsigned count = 0;
1238 for (auto d : shape) {
1239 if (d == getUnknownExtent())
1240 break;
1241 ++count;
1242 }
1243 return count;
1244}
1245
1246llvm::LogicalResult fir::SequenceType::verify(
1247 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1248 llvm::ArrayRef<int64_t> shape, mlir::Type eleTy,
1249 mlir::AffineMapAttr layoutMap) {
1250 // DIMENSION attribute can only be applied to an intrinsic or record type
1251 if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1252 ShiftType, SliceType, FieldType, LenType, HeapType, PointerType,
1253 ReferenceType, TypeDescType, SequenceType>(eleTy))
1254 return emitError() << "cannot build an array of this element type: "
1255 << eleTy << '\n';
1256 return mlir::success();
1257}
1258
1259//===----------------------------------------------------------------------===//
1260// ShapeType
1261//===----------------------------------------------------------------------===//
1262
1263mlir::Type fir::ShapeType::parse(mlir::AsmParser &parser) {
1264 return parseRankSingleton<fir::ShapeType>(parser);
1265}
1266
1267void fir::ShapeType::print(mlir::AsmPrinter &printer) const {
1268 printer << "<" << getImpl()->rank << ">";
1269}
1270
1271//===----------------------------------------------------------------------===//
1272// ShapeShiftType
1273//===----------------------------------------------------------------------===//
1274
1275mlir::Type fir::ShapeShiftType::parse(mlir::AsmParser &parser) {
1276 return parseRankSingleton<fir::ShapeShiftType>(parser);
1277}
1278
1279void fir::ShapeShiftType::print(mlir::AsmPrinter &printer) const {
1280 printer << "<" << getRank() << ">";
1281}
1282
1283//===----------------------------------------------------------------------===//
1284// ShiftType
1285//===----------------------------------------------------------------------===//
1286
1287mlir::Type fir::ShiftType::parse(mlir::AsmParser &parser) {
1288 return parseRankSingleton<fir::ShiftType>(parser);
1289}
1290
1291void fir::ShiftType::print(mlir::AsmPrinter &printer) const {
1292 printer << "<" << getRank() << ">";
1293}
1294
1295//===----------------------------------------------------------------------===//
1296// SliceType
1297//===----------------------------------------------------------------------===//
1298
1299// `slice` `<` rank `>`
1300mlir::Type fir::SliceType::parse(mlir::AsmParser &parser) {
1301 return parseRankSingleton<fir::SliceType>(parser);
1302}
1303
1304void fir::SliceType::print(mlir::AsmPrinter &printer) const {
1305 printer << "<" << getRank() << '>';
1306}
1307
1308//===----------------------------------------------------------------------===//
1309// TypeDescType
1310//===----------------------------------------------------------------------===//
1311
1312// `tdesc` `<` type `>`
1313mlir::Type fir::TypeDescType::parse(mlir::AsmParser &parser) {
1314 return parseTypeSingleton<fir::TypeDescType>(parser);
1315}
1316
1317void fir::TypeDescType::print(mlir::AsmPrinter &printer) const {
1318 printer << "<" << getOfTy() << '>';
1319}
1320
1321llvm::LogicalResult fir::TypeDescType::verify(
1322 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1323 mlir::Type eleTy) {
1324 if (mlir::isa<BoxType, BoxCharType, BoxProcType, ShapeType, ShapeShiftType,
1325 ShiftType, SliceType, FieldType, LenType, ReferenceType,
1326 TypeDescType>(eleTy))
1327 return emitError() << "cannot build a type descriptor of type: " << eleTy
1328 << '\n';
1329 return mlir::success();
1330}
1331
1332//===----------------------------------------------------------------------===//
1333// VectorType
1334//===----------------------------------------------------------------------===//
1335
1336// `vector` `<` len `:` type `>`
1337mlir::Type fir::VectorType::parse(mlir::AsmParser &parser) {
1338 int64_t len = 0;
1339 mlir::Type eleTy;
1340 if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
1341 parser.parseType(eleTy) || parser.parseGreater())
1342 return {};
1343 return fir::VectorType::get(len, eleTy);
1344}
1345
1346void fir::VectorType::print(mlir::AsmPrinter &printer) const {
1347 printer << "<" << getLen() << ':' << getEleTy() << '>';
1348}
1349
1350llvm::LogicalResult fir::VectorType::verify(
1351 llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
1352 mlir::Type eleTy) {
1353 if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
1354 return emitError() << "cannot build a vector of type " << eleTy << '\n';
1355 return mlir::success();
1356}
1357
1358bool fir::VectorType::isValidElementType(mlir::Type t) {
1359 return isa_real(t) || isa_integer(t);
1360}
1361
1362bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
1363 mlir::TupleType tuple = mlir::dyn_cast<mlir::TupleType>(Val&: ty);
1364 return tuple && tuple.size() == 2 &&
1365 (mlir::isa<fir::BoxProcType>(tuple.getType(0)) ||
1366 (acceptRawFunc && mlir::isa<mlir::FunctionType>(tuple.getType(0)))) &&
1367 fir::isa_integer(tuple.getType(1));
1368}
1369
1370bool fir::hasAbstractResult(mlir::FunctionType ty) {
1371 if (ty.getNumResults() == 0)
1372 return false;
1373 auto resultType = ty.getResult(i: 0);
1374 return mlir::isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>(
1375 resultType);
1376}
1377
1378/// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
1379/// messages only.
1380mlir::Type fir::fromRealTypeID(mlir::MLIRContext *context,
1381 llvm::Type::TypeID typeID, fir::KindTy kind) {
1382 switch (typeID) {
1383 case llvm::Type::TypeID::HalfTyID:
1384 return mlir::Float16Type::get(context);
1385 case llvm::Type::TypeID::BFloatTyID:
1386 return mlir::BFloat16Type::get(context);
1387 case llvm::Type::TypeID::FloatTyID:
1388 return mlir::Float32Type::get(context);
1389 case llvm::Type::TypeID::DoubleTyID:
1390 return mlir::Float64Type::get(context);
1391 case llvm::Type::TypeID::X86_FP80TyID:
1392 return mlir::Float80Type::get(context);
1393 case llvm::Type::TypeID::FP128TyID:
1394 return mlir::Float128Type::get(context);
1395 default:
1396 mlir::emitError(loc: mlir::UnknownLoc::get(context))
1397 << "unsupported type: !fir.real<" << kind << ">";
1398 return {};
1399 }
1400}
1401
1402//===----------------------------------------------------------------------===//
1403// BaseBoxType
1404//===----------------------------------------------------------------------===//
1405
1406mlir::Type BaseBoxType::getEleTy() const {
1407 return llvm::TypeSwitch<fir::BaseBoxType, mlir::Type>(*this)
1408 .Case<fir::BoxType, fir::ClassType>(
1409 [](auto type) { return type.getEleTy(); });
1410}
1411
1412mlir::Type BaseBoxType::getBaseAddressType() const {
1413 mlir::Type eleTy = getEleTy();
1414 if (fir::isa_ref_type(eleTy))
1415 return eleTy;
1416 return fir::ReferenceType::get(eleTy, isVolatile());
1417}
1418
1419mlir::Type BaseBoxType::unwrapInnerType() const {
1420 return fir::unwrapInnerType(getEleTy());
1421}
1422
1423static mlir::Type
1424changeTypeShape(mlir::Type type,
1425 std::optional<fir::SequenceType::ShapeRef> newShape) {
1426 return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
1427 .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
1428 if (newShape)
1429 return fir::SequenceType::get(*newShape, seqTy.getEleTy());
1430 return seqTy.getEleTy();
1431 })
1432 .Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(
1433 [&](auto t) -> mlir::Type {
1434 using FIRT = decltype(t);
1435 return FIRT::get(changeTypeShape(t.getEleTy(), newShape),
1436 t.isVolatile());
1437 })
1438 .Case<fir::PointerType, fir::HeapType>([&](auto t) -> mlir::Type {
1439 using FIRT = decltype(t);
1440 return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
1441 })
1442 .Default([&](mlir::Type t) -> mlir::Type {
1443 assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
1444 llvm::isa<mlir::NoneType>(t) ||
1445 llvm::isa<fir::CharacterType>(t)) &&
1446 "unexpected FIR leaf type");
1447 if (newShape)
1448 return fir::SequenceType::get(*newShape, t);
1449 return t;
1450 });
1451}
1452
1453fir::BaseBoxType
1454fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const {
1455 fir::SequenceType seqTy = fir::unwrapUntilSeqType(shapeMold);
1456 std::optional<fir::SequenceType::ShapeRef> newShape;
1457 if (seqTy)
1458 newShape = seqTy.getShape();
1459 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1460}
1461
1462fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewShape(int rank) const {
1463 std::optional<fir::SequenceType::ShapeRef> newShape;
1464 fir::SequenceType::Shape shapeVector;
1465 if (rank > 0) {
1466 shapeVector =
1467 fir::SequenceType::Shape(rank, fir::SequenceType::getUnknownExtent());
1468 newShape = shapeVector;
1469 }
1470 return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1471}
1472
1473fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewAttr(
1474 fir::BaseBoxType::Attribute attr) const {
1475 mlir::Type baseType = fir::unwrapRefType(getEleTy());
1476 switch (attr) {
1477 case fir::BaseBoxType::Attribute::None:
1478 break;
1479 case fir::BaseBoxType::Attribute::Allocatable:
1480 baseType = fir::HeapType::get(baseType);
1481 break;
1482 case fir::BaseBoxType::Attribute::Pointer:
1483 baseType = fir::PointerType::get(baseType);
1484 break;
1485 }
1486 return llvm::TypeSwitch<fir::BaseBoxType, fir::BaseBoxType>(*this)
1487 .Case<fir::BoxType>([baseType](auto b) {
1488 return fir::BoxType::get(baseType, b.isVolatile());
1489 })
1490 .Case<fir::ClassType>([baseType](auto b) {
1491 return fir::ClassType::get(baseType, b.isVolatile());
1492 });
1493}
1494
1495bool fir::BaseBoxType::isAssumedRank() const {
1496 if (auto seqTy =
1497 mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy())))
1498 return seqTy.hasUnknownShape();
1499 return false;
1500}
1501
1502bool fir::BaseBoxType::isPointer() const {
1503 return llvm::isa<fir::PointerType>(getEleTy());
1504}
1505
1506bool fir::BaseBoxType::isPointerOrAllocatable() const {
1507 return llvm::isa<fir::PointerType, fir::HeapType>(getEleTy());
1508}
1509
1510bool BaseBoxType::isVolatile() const { return fir::isa_volatile_type(*this); }
1511
1512//===----------------------------------------------------------------------===//
1513// FIROpsDialect
1514//===----------------------------------------------------------------------===//
1515
1516void FIROpsDialect::registerTypes() {
1517 addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, ClassType,
1518 FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
1519 LLVMPointerType, PointerType, RecordType, ReferenceType,
1520 SequenceType, ShapeType, ShapeShiftType, ShiftType, SliceType,
1521 TypeDescType, fir::VectorType, fir::DummyScopeType>();
1522 fir::ReferenceType::attachInterface<
1523 OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
1524 fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
1525 *getContext());
1526 fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
1527 *getContext());
1528 fir::LLVMPointerType::attachInterface<
1529 OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1530}
1531
1532std::optional<std::pair<uint64_t, unsigned short>>
1533fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
1534 const mlir::DataLayout &dl,
1535 const fir::KindMapping &kindMap) {
1536 if (ty.isIntOrIndexOrFloat() ||
1537 mlir::isa<mlir::ComplexType, mlir::VectorType,
1538 mlir::DataLayoutTypeInterface>(ty)) {
1539 llvm::TypeSize size = dl.getTypeSize(ty);
1540 unsigned short alignment = dl.getTypeABIAlignment(ty);
1541 return std::pair{size, alignment};
1542 }
1543 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
1544 auto result = getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
1545 if (!result)
1546 return result;
1547 auto [eleSize, eleAlign] = *result;
1548 std::uint64_t size =
1549 llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
1550 return std::pair{size, eleAlign};
1551 }
1552 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
1553 std::uint64_t size = 0;
1554 unsigned short align = 1;
1555 for (auto component : recTy.getTypeList()) {
1556 auto result = getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
1557 if (!result)
1558 return result;
1559 auto [compSize, compAlign] = *result;
1560 size =
1561 llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
1562 align = std::max(align, compAlign);
1563 }
1564 return std::pair{size, align};
1565 }
1566 if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
1567 mlir::Type intTy = mlir::IntegerType::get(
1568 context: logical.getContext(), width: kindMap.getLogicalBitsize(logical.getFKind()));
1569 return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1570 }
1571 if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
1572 mlir::Type intTy = mlir::IntegerType::get(
1573 context: character.getContext(),
1574 width: kindMap.getCharacterBitsize(character.getFKind()));
1575 auto result = getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
1576 if (!result)
1577 return result;
1578 auto [compSize, compAlign] = *result;
1579 if (character.hasConstantLen())
1580 compSize *= character.getLen();
1581 return std::pair{compSize, compAlign};
1582 }
1583 return std::nullopt;
1584}
1585
1586std::pair<std::uint64_t, unsigned short>
1587fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty,
1588 const mlir::DataLayout &dl,
1589 const fir::KindMapping &kindMap) {
1590 std::optional<std::pair<uint64_t, unsigned short>> result =
1591 getTypeSizeAndAlignment(loc, ty, dl, kindMap);
1592 if (result)
1593 return *result;
1594 TODO(loc, "computing size of a component");
1595}
1596

source code of flang/lib/Optimizer/Dialect/FIRType.cpp