1//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the definitions of the infer op interfaces defined in
10// `InferTypeOpInterface.td`.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15#define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/Location.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/PointerUnion.h"
24#include "llvm/ADT/SmallVector.h"
25
26namespace mlir {
27
28class ShapedTypeComponents;
29using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
30
31/// Reify the shape of the result of an operation (typically in terms of the
32/// shape of its operands).
33LogicalResult
34reifyResultShapes(OpBuilder &b, Operation *op,
35 ReifiedRankedShapedTypeDims &reifiedReturnShapes);
36
37/// Adaptor class to abstract the differences between whether value is from
38/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
39class ShapeAdaptor {
40public:
41 ShapeAdaptor(Type t) {
42 if (auto st = dyn_cast<ShapedType>(t))
43 val = st;
44 }
45 ShapeAdaptor(Attribute t) {
46 if (auto da = dyn_cast<DenseIntElementsAttr>(t))
47 val = da;
48 }
49 ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
50 ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
51
52 /// Returns whether the shape has a rank.
53 bool hasRank() const;
54
55 /// Returns the element type.
56 Type getElementType() const;
57
58 /// Populates the dimensions from shape referenced.
59 /// Requires: shape is ranked.
60 void getDims(SmallVectorImpl<int64_t> &res) const;
61
62 /// Populates the dimensions of the ShapeTypeComponents.
63 /// Requires: shape is ranked.
64 void getDims(ShapedTypeComponents &res) const;
65
66 /// Returns the size of the index'th dimension.
67 /// Requires: shape is ranked.
68 int64_t getDimSize(int index) const;
69
70 /// Returns whether the index'th dimension is dynamic.
71 /// Requires: shape is ranked.
72 bool isDynamicDim(int index) const {
73 return ShapedType::isDynamic(getDimSize(index));
74 }
75
76 /// Returns whether the shape is fully static.
77 bool hasStaticShape() const;
78
79 /// Returns the rank of the shape.
80 /// Requires: shape is ranked.
81 int64_t getRank() const;
82
83 /// Returns the number of elements in the shape.
84 /// Requires: hasStaticShape
85 int64_t getNumElements() const;
86
87 /// Returns whether valid (non-null) shape.
88 explicit operator bool() const { return !val.isNull(); }
89
90 /// Dumps textual repesentation to stderr.
91 void dump() const;
92
93private:
94 // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
95 // casted), or DenseIntElementsAttribute (stored as Atrtribute).
96 PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
97};
98
99/// ShapedTypeComponents that represents the components of a ShapedType.
100/// The components consist of
101/// - A ranked or unranked shape with the dimension specification match those
102/// of ShapeType's getShape() (e.g., dynamic dimension represented using
103/// ShapedType::kDynamic)
104/// - A element type, may be unset (nullptr)
105/// - A attribute, may be unset (nullptr)
106/// Used by ShapedType type inferences.
107class ShapedTypeComponents {
108 /// Internal storage type for shape.
109 using ShapeStorageT = SmallVector<int64_t, 3>;
110
111public:
112 /// Default construction is an unranked shape.
113 ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
114 ShapedTypeComponents(Type elementType)
115 : elementType(elementType), attr(nullptr), ranked(false) {}
116 ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
117 ranked = shapedType.hasRank();
118 elementType = shapedType.getElementType();
119 if (ranked)
120 dims = llvm::to_vector<4>(shapedType.getShape());
121 }
122 ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
123 ranked = adaptor.hasRank();
124 elementType = adaptor.getElementType();
125 if (ranked)
126 adaptor.getDims(res&: *this);
127 }
128 template <typename Arg, typename = std::enable_if_t<
129 std::is_constructible<ShapeStorageT, Arg>::value>>
130 ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
131 Attribute attr = nullptr)
132 : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
133 ranked(true) {}
134 ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
135 Attribute attr = nullptr)
136 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
137 ranked(true) {}
138
139 /// Return the dimensions of the shape.
140 /// Requires: shape is ranked.
141 ArrayRef<int64_t> getDims() const {
142 assert(ranked && "requires ranked shape");
143 return dims;
144 }
145
146 /// Return whether the shape has a rank.
147 bool hasRank() const { return ranked; };
148
149 /// Return the element type component.
150 Type getElementType() const { return elementType; };
151
152 /// Return the raw attribute component.
153 Attribute getAttribute() const { return attr; };
154
155private:
156 friend class ShapeAdaptor;
157
158 ShapeStorageT dims;
159 Type elementType;
160 Attribute attr;
161 bool ranked{false};
162};
163
164/// Range of values and shapes (corresponding effectively to Shapes dialect's
165/// ValueShape type concept).
166// Currently this exposes the Value (of operands) and Type of the Value. This is
167// not ideal as then one can accidentally reference an out of date shape. This
168// is done to both enable gradual switch and also as OpAdaptor doesn't currently
169// allow returning anything other than Value.
170class ValueShapeRange : public ValueRange::RangeBaseT {
171public:
172 using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>;
173
174 ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
175 ValueShapeMapFn valueToShape = nullptr)
176 : RangeBaseT(values), operandShape(operandShape),
177 valueToShape(valueToShape) {}
178 ValueShapeRange(const std::initializer_list<Value> &values)
179 : ValueShapeRange(ValueRange(values)) {}
180
181 ValueShapeRange(const ValueShapeRange &) = default;
182
183 /// Sets the Value to ShapeAdaptor mapping function and returns this.
184 ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) {
185 valueToShape = fn;
186 return *this;
187 }
188
189 ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) {
190 operandShape = fn;
191 return *this;
192 }
193
194 /// Returns the set Value to ShapeAdaptor mapping function.
195 ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
196 ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
197
198 // Accessors.
199
200 /// Returns the types of the values within this range.
201 /// Note: This returns only the types of Values in the ValueRange and not a
202 /// more refined type.
203 using type_iterator = ValueTypeIterator<iterator>;
204 using type_range = ValueTypeRange<ValueRange>;
205 type_range getTypes() const { return {begin(), end()}; }
206 auto getType() const { return getTypes(); }
207
208 /// Returns the Values in the ValueRange.
209 /// To query the most up to date shape of a Value, query the shape
210 /// using getShape below rather than using the type of the Value.
211 ValueRange getValues() const { return ValueRange(begin(), end()); };
212
213 /// Returns an argument as shape. If the argument is not constant or not a
214 /// shape, then the function returns a nullptr.
215 /// This will first query the valueToShape mapping (if set), before querying
216 /// the ValueRange.
217 ShapeAdaptor getValueAsShape(int index);
218
219 /// Returns the shape of index'th operand.
220 // TODO: Update so that operator[] references these instead to avoid
221 // accidentally refering to less refined shape.
222 ShapeAdaptor getShape(int index) const;
223
224 /// Returns the shape of the given Value.
225 ShapeAdaptor getShape(Value val) const;
226
227private:
228 // Mapping from Value to ShapedTypeComponents corresponding to shape of type
229 // of Value.
230 ValueShapeMapFn operandShape;
231
232 // Mapping from Value to ShapedTypeComponents corresponding to constant Value
233 // if interpreted as shape.
234 ValueShapeMapFn valueToShape;
235};
236
237namespace detail {
238// Helper function to infer return tensor returns types given element and
239// shape inference function.
240LogicalResult
241inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
242 SmallVectorImpl<Type> &inferredReturnTypes);
243
244/// Verifies that the inferred result types match the actual result types for
245/// the op. Precondition: op implements InferTypeOpInterface.
246LogicalResult verifyInferredResultTypes(Operation *op);
247} // namespace detail
248
// Forward declaration of the tensor type inference trait. Its definition
// appears after the generated interface declarations are included (presumably
// because that generated code refers to this trait).
namespace OpTrait {
template <typename ConcreteType> class InferTensorType;
} // namespace OpTrait
253} // namespace mlir
254
255/// Include the generated interface declarations.
256#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
257
258namespace mlir {
259namespace OpTrait {
260
261template <typename ConcreteType>
262class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
263};
264
265template <typename ConcreteType>
266class InferShapedTypeOpAdaptor
267 : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
268
269/// Tensor type inference trait that constructs a tensor from the inferred
270/// shape and elemental types.
271/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
272/// Less strict is possible (e.g., implements inferReturnTypeComponents and
273/// these always populates all element types and shapes or fails, but this\
274/// trait is currently only used where the interfaces are, so keep it
275/// restricted for now).
276template <typename ConcreteType>
277class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
278
279} // namespace OpTrait
280} // namespace mlir
281
282#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
283

// Source: mlir/include/mlir/Interfaces/InferTypeOpInterface.h