1 | //===- Types.cpp ----------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/AST/Types.h" |
10 | #include "TypeDetail.h" |
11 | #include "mlir/Tools/PDLL/AST/Context.h" |
12 | #include <optional> |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::pdll; |
16 | using namespace mlir::pdll::ast; |
17 | |
// Explicit TypeID definitions, one per PDLL AST type storage class. Each
// storage class is spelled with its fully qualified name.
// NOTE(review): presumably paired with matching explicit TypeID declarations
// in TypeDetail.h -- confirm.
18 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage) |
19 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage) |
20 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage) |
21 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage) |
22 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage) |
23 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage) |
24 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage) |
25 | MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage) |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // Type |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | TypeID Type::getTypeID() const { return impl->typeID; } |
32 | |
33 | Type Type::refineWith(Type other) const { |
34 | if (*this == other) |
35 | return *this; |
36 | |
37 | // Operation types are compatible if the operation names don't conflict. |
38 | if (auto opTy = mlir::dyn_cast<OperationType>(Val: *this)) { |
39 | auto otherOpTy = mlir::dyn_cast<ast::OperationType>(Val&: other); |
40 | if (!otherOpTy) |
41 | return nullptr; |
42 | if (!otherOpTy.getName()) |
43 | return *this; |
44 | if (!opTy.getName()) |
45 | return other; |
46 | |
47 | return nullptr; |
48 | } |
49 | |
50 | return nullptr; |
51 | } |
52 | |
53 | //===----------------------------------------------------------------------===// |
54 | // AttributeType |
55 | //===----------------------------------------------------------------------===// |
56 | |
57 | AttributeType AttributeType::get(Context &context) { |
58 | return context.getTypeUniquer().get<ImplTy>(); |
59 | } |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // ConstraintType |
63 | //===----------------------------------------------------------------------===// |
64 | |
65 | ConstraintType ConstraintType::get(Context &context) { |
66 | return context.getTypeUniquer().get<ImplTy>(); |
67 | } |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | // OperationType |
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | OperationType OperationType::get(Context &context, |
74 | std::optional<StringRef> name, |
75 | const ods::Operation *odsOp) { |
76 | return context.getTypeUniquer().get<ImplTy>( |
77 | /*initFn=*/function_ref<void(ImplTy *)>(), |
78 | args: std::make_pair(x: name.value_or(u: "" ), y&: odsOp)); |
79 | } |
80 | |
81 | std::optional<StringRef> OperationType::getName() const { |
82 | StringRef name = getImplAs<ImplTy>()->getValue().first; |
83 | return name.empty() ? std::optional<StringRef>() |
84 | : std::optional<StringRef>(name); |
85 | } |
86 | |
87 | const ods::Operation *OperationType::getODSOperation() const { |
88 | return getImplAs<ImplTy>()->getValue().second; |
89 | } |
90 | |
91 | //===----------------------------------------------------------------------===// |
92 | // RangeType |
93 | //===----------------------------------------------------------------------===// |
94 | |
95 | RangeType RangeType::get(Context &context, Type elementType) { |
96 | return context.getTypeUniquer().get<ImplTy>( |
97 | /*initFn=*/function_ref<void(ImplTy *)>(), args&: elementType); |
98 | } |
99 | |
100 | Type RangeType::getElementType() const { |
101 | return getImplAs<ImplTy>()->getValue(); |
102 | } |
103 | |
104 | //===----------------------------------------------------------------------===// |
105 | // TypeRangeType |
106 | //===----------------------------------------------------------------------===// |
107 | |
108 | bool TypeRangeType::classof(Type type) { |
109 | RangeType range = mlir::dyn_cast<RangeType>(Val&: type); |
110 | return range && mlir::isa<TypeType>(Val: range.getElementType()); |
111 | } |
112 | |
113 | TypeRangeType TypeRangeType::get(Context &context) { |
114 | return mlir::cast<TypeRangeType>( |
115 | Val: RangeType::get(context, elementType: TypeType::get(context))); |
116 | } |
117 | |
118 | //===----------------------------------------------------------------------===// |
119 | // ValueRangeType |
120 | //===----------------------------------------------------------------------===// |
121 | |
122 | bool ValueRangeType::classof(Type type) { |
123 | RangeType range = mlir::dyn_cast<RangeType>(Val&: type); |
124 | return range && mlir::isa<ValueType>(Val: range.getElementType()); |
125 | } |
126 | |
127 | ValueRangeType ValueRangeType::get(Context &context) { |
128 | return mlir::cast<ValueRangeType>( |
129 | Val: RangeType::get(context, elementType: ValueType::get(context))); |
130 | } |
131 | |
132 | //===----------------------------------------------------------------------===// |
133 | // RewriteType |
134 | //===----------------------------------------------------------------------===// |
135 | |
136 | RewriteType RewriteType::get(Context &context) { |
137 | return context.getTypeUniquer().get<ImplTy>(); |
138 | } |
139 | |
140 | //===----------------------------------------------------------------------===// |
141 | // TupleType |
142 | //===----------------------------------------------------------------------===// |
143 | |
144 | TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes, |
145 | ArrayRef<StringRef> elementNames) { |
146 | assert(elementTypes.size() == elementNames.size()); |
147 | return context.getTypeUniquer().get<ImplTy>( |
148 | /*initFn=*/function_ref<void(ImplTy *)>(), args&: elementTypes, args&: elementNames); |
149 | } |
150 | TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes) { |
151 | SmallVector<StringRef> elementNames(elementTypes.size()); |
152 | return get(context, elementTypes, elementNames); |
153 | } |
154 | |
155 | ArrayRef<Type> TupleType::getElementTypes() const { |
156 | return getImplAs<ImplTy>()->getValue().first; |
157 | } |
158 | |
159 | ArrayRef<StringRef> TupleType::getElementNames() const { |
160 | return getImplAs<ImplTy>()->getValue().second; |
161 | } |
162 | |
163 | //===----------------------------------------------------------------------===// |
164 | // TypeType |
165 | //===----------------------------------------------------------------------===// |
166 | |
167 | TypeType TypeType::get(Context &context) { |
168 | return context.getTypeUniquer().get<ImplTy>(); |
169 | } |
170 | |
171 | //===----------------------------------------------------------------------===// |
172 | // ValueType |
173 | //===----------------------------------------------------------------------===// |
174 | |
175 | ValueType ValueType::get(Context &context) { |
176 | return context.getTypeUniquer().get<ImplTy>(); |
177 | } |
178 | |