1 | //===- Types.cpp - MLIR Type Classes --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/BuiltinTypes.h" |
10 | #include "mlir/IR/Dialect.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::detail; |
14 | |
15 | //===----------------------------------------------------------------------===// |
16 | // AbstractType |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | void AbstractType::walkImmediateSubElements( |
20 | Type type, function_ref<void(Attribute)> walkAttrsFn, |
21 | function_ref<void(Type)> walkTypesFn) const { |
22 | walkImmediateSubElementsFn(type, walkAttrsFn, walkTypesFn); |
23 | } |
24 | |
25 | Type AbstractType::replaceImmediateSubElements(Type type, |
26 | ArrayRef<Attribute> replAttrs, |
27 | ArrayRef<Type> replTypes) const { |
28 | return replaceImmediateSubElementsFn(type, replAttrs, replTypes); |
29 | } |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Type |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | MLIRContext *Type::getContext() const { return getDialect().getContext(); } |
36 | |
37 | bool Type::isBF16() const { return llvm::isa<BFloat16Type>(Val: *this); } |
38 | bool Type::isF16() const { return llvm::isa<Float16Type>(Val: *this); } |
39 | bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(Val: *this); } |
40 | bool Type::isF32() const { return llvm::isa<Float32Type>(Val: *this); } |
41 | bool Type::isF64() const { return llvm::isa<Float64Type>(Val: *this); } |
42 | bool Type::isF80() const { return llvm::isa<Float80Type>(Val: *this); } |
43 | bool Type::isF128() const { return llvm::isa<Float128Type>(Val: *this); } |
44 | |
45 | bool Type::isFloat() const { return llvm::isa<FloatType>(Val: *this); } |
46 | |
47 | /// Return true if this is a float type with the specified width. |
48 | bool Type::isFloat(unsigned width) const { |
49 | if (auto fltTy = llvm::dyn_cast<FloatType>(*this)) |
50 | return fltTy.getWidth() == width; |
51 | return false; |
52 | } |
53 | |
54 | bool Type::isIndex() const { return llvm::isa<IndexType>(Val: *this); } |
55 | |
56 | bool Type::isInteger() const { return llvm::isa<IntegerType>(Val: *this); } |
57 | |
58 | bool Type::isInteger(unsigned width) const { |
59 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
60 | return intTy.getWidth() == width; |
61 | return false; |
62 | } |
63 | |
64 | bool Type::isSignlessInteger() const { |
65 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
66 | return intTy.isSignless(); |
67 | return false; |
68 | } |
69 | |
70 | bool Type::isSignlessInteger(unsigned width) const { |
71 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
72 | return intTy.isSignless() && intTy.getWidth() == width; |
73 | return false; |
74 | } |
75 | |
76 | bool Type::isSignedInteger() const { |
77 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
78 | return intTy.isSigned(); |
79 | return false; |
80 | } |
81 | |
82 | bool Type::isSignedInteger(unsigned width) const { |
83 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
84 | return intTy.isSigned() && intTy.getWidth() == width; |
85 | return false; |
86 | } |
87 | |
88 | bool Type::isUnsignedInteger() const { |
89 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
90 | return intTy.isUnsigned(); |
91 | return false; |
92 | } |
93 | |
94 | bool Type::isUnsignedInteger(unsigned width) const { |
95 | if (auto intTy = llvm::dyn_cast<IntegerType>(*this)) |
96 | return intTy.isUnsigned() && intTy.getWidth() == width; |
97 | return false; |
98 | } |
99 | |
100 | bool Type::isSignlessIntOrIndex() const { |
101 | return isSignlessInteger() || llvm::isa<IndexType>(Val: *this); |
102 | } |
103 | |
104 | bool Type::isSignlessIntOrIndexOrFloat() const { |
105 | return isSignlessInteger() || llvm::isa<IndexType, FloatType>(Val: *this); |
106 | } |
107 | |
108 | bool Type::isSignlessIntOrFloat() const { |
109 | return isSignlessInteger() || llvm::isa<FloatType>(Val: *this); |
110 | } |
111 | |
112 | bool Type::isIntOrIndex() const { |
113 | return llvm::isa<IntegerType>(Val: *this) || isIndex(); |
114 | } |
115 | |
116 | bool Type::isIntOrFloat() const { |
117 | return llvm::isa<IntegerType, FloatType>(Val: *this); |
118 | } |
119 | |
120 | bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); } |
121 | |
122 | unsigned Type::getIntOrFloatBitWidth() const { |
123 | assert(isIntOrFloat() && "only integers and floats have a bitwidth" ); |
124 | if (auto intType = llvm::dyn_cast<IntegerType>(*this)) |
125 | return intType.getWidth(); |
126 | return llvm::cast<FloatType>(*this).getWidth(); |
127 | } |
128 | |