//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir
24 | |
25 | LogicalResult |
26 | mlir::reifyResultShapes(OpBuilder &b, Operation *op, |
27 | ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
28 | auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); |
29 | if (!reifiableOp) |
30 | return failure(); |
31 | LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes); |
32 | #ifndef NDEBUG |
33 | if (failed(result: status)) |
34 | return failure(); |
35 | // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced |
36 | // a correct result. |
37 | int64_t resultIdx = 0; |
38 | for (OpResult result : op->getResults()) { |
39 | auto shapedType = dyn_cast<ShapedType>(result.getType()); |
40 | if (!shapedType) |
41 | continue; |
42 | if (!shapedType.hasRank()) { |
43 | // Nothing to check for unranked shaped values. |
44 | ++resultIdx; |
45 | continue; |
46 | } |
47 | // Assert one OpFoldResult per dimension. |
48 | assert(shapedType.getRank() == |
49 | static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) && |
50 | "incorrect implementation of ReifyRankedShapedTypeOpInterface" ); |
51 | for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) { |
52 | // reifyResultShapes must return: |
53 | // * Attribute for static dimensions |
54 | // * Value for dynamic dimensions |
55 | assert(shapedType.isDynamicDim(dim) == |
56 | reifiedReturnShapes[resultIdx][dim].is<Value>() && |
57 | "incorrect implementation of ReifyRankedShapedTypeOpInterface" ); |
58 | } |
59 | ++resultIdx; |
60 | } |
61 | // Assert that every shaped value result was reified. |
62 | assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) && |
63 | "incorrect implementation of ReifyRankedShapedTypeOpInterface" ); |
64 | #endif // NDEBUG |
65 | return status; |
66 | } |
67 | |
68 | bool ShapeAdaptor::hasRank() const { |
69 | if (val.isNull()) |
70 | return false; |
71 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
72 | return cast<ShapedType>(t).hasRank(); |
73 | if (val.is<Attribute>()) |
74 | return true; |
75 | return val.get<ShapedTypeComponents *>()->hasRank(); |
76 | } |
77 | |
78 | Type ShapeAdaptor::getElementType() const { |
79 | if (val.isNull()) |
80 | return nullptr; |
81 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
82 | return cast<ShapedType>(t).getElementType(); |
83 | if (val.is<Attribute>()) |
84 | return nullptr; |
85 | return val.get<ShapedTypeComponents *>()->getElementType(); |
86 | } |
87 | |
88 | void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const { |
89 | assert(hasRank()); |
90 | if (auto t = llvm::dyn_cast_if_present<Type>(Val: val)) { |
91 | ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape(); |
92 | res.assign(in_start: vals.begin(), in_end: vals.end()); |
93 | } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) { |
94 | auto dattr = cast<DenseIntElementsAttr>(Val&: attr); |
95 | res.clear(); |
96 | res.reserve(N: dattr.size()); |
97 | for (auto it : dattr.getValues<APInt>()) |
98 | res.push_back(it.getSExtValue()); |
99 | } else { |
100 | auto vals = val.get<ShapedTypeComponents *>()->getDims(); |
101 | res.assign(in_start: vals.begin(), in_end: vals.end()); |
102 | } |
103 | } |
104 | |
105 | void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { |
106 | assert(hasRank()); |
107 | res.ranked = true; |
108 | getDims(res&: res.dims); |
109 | } |
110 | |
111 | int64_t ShapeAdaptor::getDimSize(int index) const { |
112 | assert(hasRank()); |
113 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
114 | return cast<ShapedType>(t).getDimSize(index); |
115 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) |
116 | return cast<DenseIntElementsAttr>(attr) |
117 | .getValues<APInt>()[index] |
118 | .getSExtValue(); |
119 | auto *stc = val.get<ShapedTypeComponents *>(); |
120 | return stc->getDims()[index]; |
121 | } |
122 | |
123 | int64_t ShapeAdaptor::getRank() const { |
124 | assert(hasRank()); |
125 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
126 | return cast<ShapedType>(t).getRank(); |
127 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) |
128 | return cast<DenseIntElementsAttr>(Val&: attr).size(); |
129 | return val.get<ShapedTypeComponents *>()->getDims().size(); |
130 | } |
131 | |
132 | bool ShapeAdaptor::hasStaticShape() const { |
133 | if (!hasRank()) |
134 | return false; |
135 | |
136 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
137 | return cast<ShapedType>(t).hasStaticShape(); |
138 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) { |
139 | auto dattr = cast<DenseIntElementsAttr>(Val&: attr); |
140 | for (auto index : dattr.getValues<APInt>()) |
141 | if (ShapedType::isDynamic(index.getSExtValue())) |
142 | return false; |
143 | return true; |
144 | } |
145 | auto *stc = val.get<ShapedTypeComponents *>(); |
146 | return llvm::none_of(stc->getDims(), ShapedType::isDynamic); |
147 | } |
148 | |
149 | int64_t ShapeAdaptor::getNumElements() const { |
150 | assert(hasStaticShape() && "cannot get element count of dynamic shaped type" ); |
151 | |
152 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
153 | return cast<ShapedType>(t).getNumElements(); |
154 | |
155 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) { |
156 | auto dattr = cast<DenseIntElementsAttr>(Val&: attr); |
157 | int64_t num = 1; |
158 | for (auto index : dattr.getValues<APInt>()) { |
159 | num *= index.getZExtValue(); |
160 | assert(num >= 0 && "integer overflow in element count computation" ); |
161 | } |
162 | return num; |
163 | } |
164 | |
165 | auto *stc = val.get<ShapedTypeComponents *>(); |
166 | int64_t num = 1; |
167 | for (int64_t dim : stc->getDims()) { |
168 | num *= dim; |
169 | assert(num >= 0 && "integer overflow in element count computation" ); |
170 | } |
171 | return num; |
172 | } |
173 | |
174 | void ShapeAdaptor::dump() const { |
175 | if (!hasRank()) { |
176 | llvm::errs() << "<<unranked>>\n" ; |
177 | return; |
178 | } |
179 | |
180 | SmallVector<int64_t> dims; |
181 | getDims(res&: dims); |
182 | auto mapped = llvm::map_range(C&: dims, F: [](int64_t dim) -> std::string { |
183 | if (ShapedType::isDynamic(dim)) |
184 | return "?" ; |
185 | return llvm::formatv(Fmt: "{0}" , Vals&: dim).str(); |
186 | }); |
187 | llvm::errs() << "rank = " << getRank() << " dims = [" ; |
188 | llvm::interleave(c: mapped, os&: llvm::errs(), separator: "x" ); |
189 | llvm::errs() << "]\n" ; |
190 | } |
191 | |
192 | ShapeAdaptor ValueShapeRange::getValueAsShape(int index) { |
193 | Value val = operator[](Index: index); |
194 | if (valueToShape) |
195 | if (ShapeAdaptor ret = valueToShape(val)) |
196 | return ret; |
197 | |
198 | DenseIntElementsAttr attr; |
199 | if (!matchPattern(value: val, pattern: m_Constant(bind_value: &attr))) |
200 | return nullptr; |
201 | if (attr.getType().getRank() != 1) |
202 | return nullptr; |
203 | return attr; |
204 | } |
205 | |
206 | ShapeAdaptor ValueShapeRange::getShape(Value val) const { |
207 | if (operandShape) |
208 | if (ShapeAdaptor ret = operandShape(val)) |
209 | return ret; |
210 | return val.getType(); |
211 | } |
212 | |
213 | ShapeAdaptor ValueShapeRange::getShape(int index) const { |
214 | if (index < 0 || static_cast<size_t>(index) >= size()) |
215 | return nullptr; |
216 | return getShape(val: operator[](Index: index)); |
217 | } |
218 | |
219 | LogicalResult mlir::detail::inferReturnTensorTypes( |
220 | ArrayRef<ShapedTypeComponents> retComponents, |
221 | SmallVectorImpl<Type> &inferredReturnTypes) { |
222 | for (const auto &shapeAndType : retComponents) { |
223 | Type elementTy = shapeAndType.getElementType(); |
224 | assert(elementTy && "element type required to construct tensor" ); |
225 | |
226 | Attribute attr = shapeAndType.getAttribute(); |
227 | if (shapeAndType.hasRank()) { |
228 | inferredReturnTypes.push_back( |
229 | RankedTensorType::get(shapeAndType.getDims(), elementTy, attr)); |
230 | } else { |
231 | assert(attr == nullptr && "attribute not supported" ); |
232 | inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); |
233 | } |
234 | } |
235 | return success(); |
236 | } |
237 | |
238 | LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { |
239 | SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes()); |
240 | auto retTypeFn = cast<InferTypeOpInterface>(op); |
241 | auto result = retTypeFn.refineReturnTypes( |
242 | op->getContext(), op->getLoc(), op->getOperands(), |
243 | op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(), |
244 | inferredReturnTypes); |
245 | if (failed(result)) |
246 | op->emitOpError() << "failed to infer returned types" ; |
247 | |
248 | return result; |
249 | } |
250 | |