//===- Types.cpp - MLIR Type Classes --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/IR/BuiltinTypes.h" |
10 | #include "mlir/IR/Dialect.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::detail; |
14 | |
//===----------------------------------------------------------------------===//
// AbstractType
//===----------------------------------------------------------------------===//
18 | |
19 | void AbstractType::walkImmediateSubElements( |
20 | Type type, function_ref<void(Attribute)> walkAttrsFn, |
21 | function_ref<void(Type)> walkTypesFn) const { |
22 | walkImmediateSubElementsFn(type, walkAttrsFn, walkTypesFn); |
23 | } |
24 | |
25 | Type AbstractType::replaceImmediateSubElements(Type type, |
26 | ArrayRef<Attribute> replAttrs, |
27 | ArrayRef<Type> replTypes) const { |
28 | return replaceImmediateSubElementsFn(type, replAttrs, replTypes); |
29 | } |
30 | |
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
34 | |
35 | MLIRContext *Type::getContext() const { return getDialect().getContext(); } |
36 | |
37 | bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); } |
38 | bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); } |
39 | bool Type::isFloat8E5M2FNUZ() const { |
40 | return llvm::isa<Float8E5M2FNUZType>(*this); |
41 | } |
42 | bool Type::isFloat8E4M3FNUZ() const { |
43 | return llvm::isa<Float8E4M3FNUZType>(*this); |
44 | } |
45 | bool Type::isFloat8E4M3B11FNUZ() const { |
46 | return llvm::isa<Float8E4M3B11FNUZType>(*this); |
47 | } |
48 | bool Type::isBF16() const { return llvm::isa<BFloat16Type>(Val: *this); } |
49 | bool Type::isF16() const { return llvm::isa<Float16Type>(Val: *this); } |
50 | bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(Val: *this); } |
51 | bool Type::isF32() const { return llvm::isa<Float32Type>(Val: *this); } |
52 | bool Type::isF64() const { return llvm::isa<Float64Type>(*this); } |
53 | bool Type::isF80() const { return llvm::isa<Float80Type>(*this); } |
54 | bool Type::isF128() const { return llvm::isa<Float128Type>(*this); } |
55 | |
56 | bool Type::isIndex() const { return llvm::isa<IndexType>(Val: *this); } |
57 | |
58 | bool Type::isInteger() const { return llvm::isa<IntegerType>(Val: *this); } |
59 | |
60 | /// Return true if this is an integer type with the specified width. |
61 | bool Type::isInteger(unsigned width) const { |
62 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
63 | return intTy.getWidth() == width; |
64 | return false; |
65 | } |
66 | |
67 | bool Type::isSignlessInteger() const { |
68 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
69 | return intTy.isSignless(); |
70 | return false; |
71 | } |
72 | |
73 | bool Type::isSignlessInteger(unsigned width) const { |
74 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
75 | return intTy.isSignless() && intTy.getWidth() == width; |
76 | return false; |
77 | } |
78 | |
79 | bool Type::isSignedInteger() const { |
80 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
81 | return intTy.isSigned(); |
82 | return false; |
83 | } |
84 | |
85 | bool Type::isSignedInteger(unsigned width) const { |
86 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
87 | return intTy.isSigned() && intTy.getWidth() == width; |
88 | return false; |
89 | } |
90 | |
91 | bool Type::isUnsignedInteger() const { |
92 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
93 | return intTy.isUnsigned(); |
94 | return false; |
95 | } |
96 | |
97 | bool Type::isUnsignedInteger(unsigned width) const { |
98 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
99 | return intTy.isUnsigned() && intTy.getWidth() == width; |
100 | return false; |
101 | } |
102 | |
103 | bool Type::isSignlessIntOrIndex() const { |
104 | return isSignlessInteger() || llvm::isa<IndexType>(Val: *this); |
105 | } |
106 | |
107 | bool Type::isSignlessIntOrIndexOrFloat() const { |
108 | return isSignlessInteger() || llvm::isa<IndexType, FloatType>(Val: *this); |
109 | } |
110 | |
111 | bool Type::isSignlessIntOrFloat() const { |
112 | return isSignlessInteger() || llvm::isa<FloatType>(Val: *this); |
113 | } |
114 | |
115 | bool Type::isIntOrIndex() const { |
116 | return llvm::isa<IntegerType>(Val: *this) || isIndex(); |
117 | } |
118 | |
119 | bool Type::isIntOrFloat() const { |
120 | return llvm::isa<IntegerType, FloatType>(Val: *this); |
121 | } |
122 | |
123 | bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); } |
124 | |
125 | unsigned Type::getIntOrFloatBitWidth() const { |
126 | assert(isIntOrFloat() && "only integers and floats have a bitwidth" ); |
127 | if (auto intType = llvm::dyn_cast<IntegerType>(*this)) |
128 | return intType.getWidth(); |
129 | return llvm::cast<FloatType>(Val: *this).getWidth(); |
130 | } |
131 | |