//===- TypeUtilities.h - Helper functions for type queries ------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines generic type utilities. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_TYPEUTILITIES_H |
14 | #define MLIR_IR_TYPEUTILITIES_H |
15 | |
16 | #include "mlir/IR/Operation.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | |
19 | namespace mlir { |
20 | |
21 | class Attribute; |
22 | class TupleType; |
23 | class Type; |
24 | class TypeRange; |
25 | class Value; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // Utility Functions |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | /// Return the element type or return the type itself. |
32 | Type getElementTypeOrSelf(Type type); |
33 | |
34 | /// Return the element type or return the type itself. |
35 | Type getElementTypeOrSelf(Attribute attr); |
36 | Type getElementTypeOrSelf(Value val); |
37 | |
38 | /// Get the types within a nested Tuple. A helper for the class method that |
39 | /// handles storage concerns, which is tricky to do in tablegen. |
40 | SmallVector<Type, 10> getFlattenedTypes(TupleType t); |
41 | |
42 | /// Return true if the specified type is an opaque type with the specified |
43 | /// dialect and typeData. |
44 | bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData); |
45 | |
46 | /// Returns success if the given two shapes are compatible. That is, they have |
47 | /// the same size and each pair of the elements are equal or one of them is |
48 | /// dynamic. |
49 | LogicalResult verifyCompatibleShape(ArrayRef<int64_t> shape1, |
50 | ArrayRef<int64_t> shape2); |
51 | |
52 | /// Returns success if the given two types have compatible shape. That is, |
53 | /// they are both scalars (not shaped), or they are both shaped types and at |
54 | /// least one is unranked or they have compatible dimensions. Dimensions are |
55 | /// compatible if at least one is dynamic or both are equal. The element type |
56 | /// does not matter. |
57 | LogicalResult verifyCompatibleShape(Type type1, Type type2); |
58 | |
59 | /// Returns success if the given two arrays have the same number of elements and |
60 | /// each pair wise entries have compatible shape. |
61 | LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2); |
62 | |
63 | /// Returns success if all given types have compatible shapes. That is, they are |
64 | /// all scalars (not shaped), or they are all shaped types and any ranked shapes |
65 | /// have compatible dimensions. The element type does not matter. |
66 | LogicalResult verifyCompatibleShapes(TypeRange types); |
67 | |
68 | /// Dimensions are compatible if all non-dynamic dims are equal. |
69 | LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims); |
70 | |
71 | /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any |
72 | /// types are inserted, `storage` is used to hold the new type list. The new |
73 | /// type list is returned. `indices` must be sorted by increasing index. |
74 | TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices, |
75 | TypeRange newTypes, SmallVectorImpl<Type> &storage); |
76 | |
77 | /// Filters out any elements referenced by `indices`. If any types are removed, |
78 | /// `storage` is used to hold the new type list. Returns the new type list. |
79 | TypeRange filterTypesOut(TypeRange types, const BitVector &indices, |
80 | SmallVectorImpl<Type> &storage); |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // Utility Iterators |
84 | //===----------------------------------------------------------------------===// |
85 | |
86 | // An iterator for the element types of an op's operands of shaped types. |
87 | class OperandElementTypeIterator final |
88 | : public llvm::mapped_iterator_base<OperandElementTypeIterator, |
89 | Operation::operand_iterator, Type> { |
90 | public: |
91 | using BaseT::BaseT; |
92 | |
93 | /// Map the element to the iterator result type. |
94 | Type mapElement(Value value) const; |
95 | }; |
96 | |
97 | using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>; |
98 | |
99 | // An iterator for the tensor element types of an op's results of shaped types. |
100 | class ResultElementTypeIterator final |
101 | : public llvm::mapped_iterator_base<ResultElementTypeIterator, |
102 | Operation::result_iterator, Type> { |
103 | public: |
104 | using BaseT::BaseT; |
105 | |
106 | /// Map the element to the iterator result type. |
107 | Type mapElement(Value value) const; |
108 | }; |
109 | |
110 | using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>; |
111 | |
112 | } // namespace mlir |
113 | |
114 | #endif // MLIR_IR_TYPEUTILITIES_H |
115 | |