1 | //===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains the definitions of the infer op interfaces defined in |
10 | // `InferTypeOpInterface.td`. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ |
15 | #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ |
16 | |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "mlir/IR/Location.h" |
21 | #include "mlir/IR/OpDefinition.h" |
22 | #include "mlir/Support/LLVM.h" |
23 | #include "llvm/ADT/PointerUnion.h" |
24 | #include "llvm/ADT/SmallVector.h" |
25 | |
26 | namespace mlir { |
27 | |
28 | class ShapedTypeComponents; |
29 | using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>; |
30 | |
31 | /// Reify the shape of the result of an operation (typically in terms of the |
32 | /// shape of its operands). |
33 | LogicalResult |
34 | reifyResultShapes(OpBuilder &b, Operation *op, |
35 | ReifiedRankedShapedTypeDims &reifiedReturnShapes); |
36 | |
37 | /// Adaptor class to abstract the differences between whether value is from |
38 | /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. |
39 | class ShapeAdaptor { |
40 | public: |
41 | ShapeAdaptor(Type t) { |
42 | if (auto st = dyn_cast<ShapedType>(t)) |
43 | val = st; |
44 | } |
45 | ShapeAdaptor(Attribute t) { |
46 | if (auto da = dyn_cast<DenseIntElementsAttr>(t)) |
47 | val = da; |
48 | } |
49 | ShapeAdaptor(ShapedTypeComponents *components) : val(components) {} |
50 | ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {} |
51 | |
52 | /// Returns whether the shape has a rank. |
53 | bool hasRank() const; |
54 | |
55 | /// Returns the element type. |
56 | Type getElementType() const; |
57 | |
58 | /// Populates the dimensions from shape referenced. |
59 | /// Requires: shape is ranked. |
60 | void getDims(SmallVectorImpl<int64_t> &res) const; |
61 | |
62 | /// Populates the dimensions of the ShapeTypeComponents. |
63 | /// Requires: shape is ranked. |
64 | void getDims(ShapedTypeComponents &res) const; |
65 | |
66 | /// Returns the size of the index'th dimension. |
67 | /// Requires: shape is ranked. |
68 | int64_t getDimSize(int index) const; |
69 | |
70 | /// Returns whether the index'th dimension is dynamic. |
71 | /// Requires: shape is ranked. |
72 | bool isDynamicDim(int index) const { |
73 | return ShapedType::isDynamic(getDimSize(index)); |
74 | } |
75 | |
76 | /// Returns whether the shape is fully static. |
77 | bool hasStaticShape() const; |
78 | |
79 | /// Returns the rank of the shape. |
80 | /// Requires: shape is ranked. |
81 | int64_t getRank() const; |
82 | |
83 | /// Returns the number of elements in the shape. |
84 | /// Requires: hasStaticShape |
85 | int64_t getNumElements() const; |
86 | |
87 | /// Returns whether valid (non-null) shape. |
88 | explicit operator bool() const { return !val.isNull(); } |
89 | |
90 | /// Dumps textual repesentation to stderr. |
91 | void dump() const; |
92 | |
93 | private: |
94 | // Union storing either ShapedTypeComponents, ShapedType (stored as Type and |
95 | // casted), or DenseIntElementsAttribute (stored as Atrtribute). |
96 | PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr; |
97 | }; |
98 | |
99 | /// ShapedTypeComponents that represents the components of a ShapedType. |
100 | /// The components consist of |
101 | /// - A ranked or unranked shape with the dimension specification match those |
102 | /// of ShapeType's getShape() (e.g., dynamic dimension represented using |
103 | /// ShapedType::kDynamic) |
104 | /// - A element type, may be unset (nullptr) |
105 | /// - A attribute, may be unset (nullptr) |
106 | /// Used by ShapedType type inferences. |
107 | class ShapedTypeComponents { |
108 | /// Internal storage type for shape. |
109 | using ShapeStorageT = SmallVector<int64_t, 3>; |
110 | |
111 | public: |
112 | /// Default construction is an unranked shape. |
113 | ShapedTypeComponents() : elementType(nullptr), attr(nullptr){}; |
114 | ShapedTypeComponents(Type elementType) |
115 | : elementType(elementType), attr(nullptr), ranked(false) {} |
116 | ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) { |
117 | ranked = shapedType.hasRank(); |
118 | elementType = shapedType.getElementType(); |
119 | if (ranked) |
120 | dims = llvm::to_vector<4>(shapedType.getShape()); |
121 | } |
122 | ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) { |
123 | ranked = adaptor.hasRank(); |
124 | elementType = adaptor.getElementType(); |
125 | if (ranked) |
126 | adaptor.getDims(res&: *this); |
127 | } |
128 | template <typename Arg, typename = std::enable_if_t< |
129 | std::is_constructible<ShapeStorageT, Arg>::value>> |
130 | ShapedTypeComponents(Arg &&arg, Type elementType = nullptr, |
131 | Attribute attr = nullptr) |
132 | : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr), |
133 | ranked(true) {} |
134 | ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr, |
135 | Attribute attr = nullptr) |
136 | : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr), |
137 | ranked(true) {} |
138 | |
139 | /// Return the dimensions of the shape. |
140 | /// Requires: shape is ranked. |
141 | ArrayRef<int64_t> getDims() const { |
142 | assert(ranked && "requires ranked shape" ); |
143 | return dims; |
144 | } |
145 | |
146 | /// Return whether the shape has a rank. |
147 | bool hasRank() const { return ranked; }; |
148 | |
149 | /// Return the element type component. |
150 | Type getElementType() const { return elementType; }; |
151 | |
152 | /// Return the raw attribute component. |
153 | Attribute getAttribute() const { return attr; }; |
154 | |
155 | private: |
156 | friend class ShapeAdaptor; |
157 | |
158 | ShapeStorageT dims; |
159 | Type elementType; |
160 | Attribute attr; |
161 | bool ranked{false}; |
162 | }; |
163 | |
164 | /// Range of values and shapes (corresponding effectively to Shapes dialect's |
165 | /// ValueShape type concept). |
166 | // Currently this exposes the Value (of operands) and Type of the Value. This is |
167 | // not ideal as then one can accidentally reference an out of date shape. This |
168 | // is done to both enable gradual switch and also as OpAdaptor doesn't currently |
169 | // allow returning anything other than Value. |
170 | class ValueShapeRange : public ValueRange::RangeBaseT { |
171 | public: |
172 | using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>; |
173 | |
174 | ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr, |
175 | ValueShapeMapFn valueToShape = nullptr) |
176 | : RangeBaseT(values), operandShape(operandShape), |
177 | valueToShape(valueToShape) {} |
178 | ValueShapeRange(const std::initializer_list<Value> &values) |
179 | : ValueShapeRange(ValueRange(values)) {} |
180 | |
181 | ValueShapeRange(const ValueShapeRange &) = default; |
182 | |
183 | /// Sets the Value to ShapeAdaptor mapping function and returns this. |
184 | ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) { |
185 | valueToShape = fn; |
186 | return *this; |
187 | } |
188 | |
189 | ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) { |
190 | operandShape = fn; |
191 | return *this; |
192 | } |
193 | |
194 | /// Returns the set Value to ShapeAdaptor mapping function. |
195 | ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; } |
196 | ValueShapeMapFn getOperandShapeMapping() const { return operandShape; } |
197 | |
198 | // Accessors. |
199 | |
200 | /// Returns the types of the values within this range. |
201 | /// Note: This returns only the types of Values in the ValueRange and not a |
202 | /// more refined type. |
203 | using type_iterator = ValueTypeIterator<iterator>; |
204 | using type_range = ValueTypeRange<ValueRange>; |
205 | type_range getTypes() const { return {begin(), end()}; } |
206 | auto getType() const { return getTypes(); } |
207 | |
208 | /// Returns the Values in the ValueRange. |
209 | /// To query the most up to date shape of a Value, query the shape |
210 | /// using getShape below rather than using the type of the Value. |
211 | ValueRange getValues() const { return ValueRange(begin(), end()); }; |
212 | |
213 | /// Returns an argument as shape. If the argument is not constant or not a |
214 | /// shape, then the function returns a nullptr. |
215 | /// This will first query the valueToShape mapping (if set), before querying |
216 | /// the ValueRange. |
217 | ShapeAdaptor getValueAsShape(int index); |
218 | |
219 | /// Returns the shape of index'th operand. |
220 | // TODO: Update so that operator[] references these instead to avoid |
221 | // accidentally refering to less refined shape. |
222 | ShapeAdaptor getShape(int index) const; |
223 | |
224 | /// Returns the shape of the given Value. |
225 | ShapeAdaptor getShape(Value val) const; |
226 | |
227 | private: |
228 | // Mapping from Value to ShapedTypeComponents corresponding to shape of type |
229 | // of Value. |
230 | ValueShapeMapFn operandShape; |
231 | |
232 | // Mapping from Value to ShapedTypeComponents corresponding to constant Value |
233 | // if interpreted as shape. |
234 | ValueShapeMapFn valueToShape; |
235 | }; |
236 | |
237 | namespace detail { |
238 | // Helper function to infer return tensor returns types given element and |
239 | // shape inference function. |
240 | LogicalResult |
241 | inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents, |
242 | SmallVectorImpl<Type> &inferredReturnTypes); |
243 | |
244 | /// Verifies that the inferred result types match the actual result types for |
245 | /// the op. Precondition: op implements InferTypeOpInterface. |
246 | LogicalResult verifyInferredResultTypes(Operation *op); |
247 | } // namespace detail |
248 | |
namespace OpTrait {
/// Forward declaration only; the trait itself is defined after the generated
/// interface declarations are included (presumably because the generated code
/// refers to it).
template <typename ConcreteType> class InferTensorType;
} // namespace OpTrait
253 | } // namespace mlir |
254 | |
255 | /// Include the generated interface declarations. |
256 | #include "mlir/Interfaces/InferTypeOpInterface.h.inc" |
257 | |
258 | namespace mlir { |
259 | namespace OpTrait { |
260 | |
261 | template <typename ConcreteType> |
262 | class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> { |
263 | }; |
264 | |
265 | template <typename ConcreteType> |
266 | class InferShapedTypeOpAdaptor |
267 | : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {}; |
268 | |
269 | /// Tensor type inference trait that constructs a tensor from the inferred |
270 | /// shape and elemental types. |
271 | /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. |
272 | /// Less strict is possible (e.g., implements inferReturnTypeComponents and |
273 | /// these always populates all element types and shapes or fails, but this\ |
274 | /// trait is currently only used where the interfaces are, so keep it |
275 | /// restricted for now). |
276 | template <typename ConcreteType> |
277 | class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {}; |
278 | |
279 | } // namespace OpTrait |
280 | } // namespace mlir |
281 | |
282 | #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_ |
283 | |