1//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the definitions of the infer op interfaces defined in
10// `InferTypeOpInterface.td`.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
18
19using namespace mlir;
20
21namespace mlir {
22#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
23} // namespace mlir
24
25LogicalResult
26mlir::reifyResultShapes(OpBuilder &b, Operation *op,
27 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
28 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
29 if (!reifiableOp)
30 return failure();
31 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
32#ifndef NDEBUG
33 if (failed(result: status))
34 return failure();
35 // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
36 // a correct result.
37 int64_t resultIdx = 0;
38 for (OpResult result : op->getResults()) {
39 auto shapedType = dyn_cast<ShapedType>(result.getType());
40 if (!shapedType)
41 continue;
42 if (!shapedType.hasRank()) {
43 // Nothing to check for unranked shaped values.
44 ++resultIdx;
45 continue;
46 }
47 // Assert one OpFoldResult per dimension.
48 assert(shapedType.getRank() ==
49 static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
50 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
51 for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
52 // reifyResultShapes must return:
53 // * Attribute for static dimensions
54 // * Value for dynamic dimensions
55 assert(shapedType.isDynamicDim(dim) ==
56 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
57 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
58 }
59 ++resultIdx;
60 }
61 // Assert that every shaped value result was reified.
62 assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
63 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
64#endif // NDEBUG
65 return status;
66}
67
68bool ShapeAdaptor::hasRank() const {
69 if (val.isNull())
70 return false;
71 if (auto t = llvm::dyn_cast_if_present<Type>(val))
72 return cast<ShapedType>(t).hasRank();
73 if (val.is<Attribute>())
74 return true;
75 return val.get<ShapedTypeComponents *>()->hasRank();
76}
77
78Type ShapeAdaptor::getElementType() const {
79 if (val.isNull())
80 return nullptr;
81 if (auto t = llvm::dyn_cast_if_present<Type>(val))
82 return cast<ShapedType>(t).getElementType();
83 if (val.is<Attribute>())
84 return nullptr;
85 return val.get<ShapedTypeComponents *>()->getElementType();
86}
87
88void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
89 assert(hasRank());
90 if (auto t = llvm::dyn_cast_if_present<Type>(Val: val)) {
91 ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
92 res.assign(in_start: vals.begin(), in_end: vals.end());
93 } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) {
94 auto dattr = cast<DenseIntElementsAttr>(Val&: attr);
95 res.clear();
96 res.reserve(N: dattr.size());
97 for (auto it : dattr.getValues<APInt>())
98 res.push_back(it.getSExtValue());
99 } else {
100 auto vals = val.get<ShapedTypeComponents *>()->getDims();
101 res.assign(in_start: vals.begin(), in_end: vals.end());
102 }
103}
104
105void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
106 assert(hasRank());
107 res.ranked = true;
108 getDims(res&: res.dims);
109}
110
111int64_t ShapeAdaptor::getDimSize(int index) const {
112 assert(hasRank());
113 if (auto t = llvm::dyn_cast_if_present<Type>(val))
114 return cast<ShapedType>(t).getDimSize(index);
115 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
116 return cast<DenseIntElementsAttr>(attr)
117 .getValues<APInt>()[index]
118 .getSExtValue();
119 auto *stc = val.get<ShapedTypeComponents *>();
120 return stc->getDims()[index];
121}
122
123int64_t ShapeAdaptor::getRank() const {
124 assert(hasRank());
125 if (auto t = llvm::dyn_cast_if_present<Type>(val))
126 return cast<ShapedType>(t).getRank();
127 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val))
128 return cast<DenseIntElementsAttr>(Val&: attr).size();
129 return val.get<ShapedTypeComponents *>()->getDims().size();
130}
131
132bool ShapeAdaptor::hasStaticShape() const {
133 if (!hasRank())
134 return false;
135
136 if (auto t = llvm::dyn_cast_if_present<Type>(val))
137 return cast<ShapedType>(t).hasStaticShape();
138 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) {
139 auto dattr = cast<DenseIntElementsAttr>(Val&: attr);
140 for (auto index : dattr.getValues<APInt>())
141 if (ShapedType::isDynamic(index.getSExtValue()))
142 return false;
143 return true;
144 }
145 auto *stc = val.get<ShapedTypeComponents *>();
146 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
147}
148
149int64_t ShapeAdaptor::getNumElements() const {
150 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
151
152 if (auto t = llvm::dyn_cast_if_present<Type>(val))
153 return cast<ShapedType>(t).getNumElements();
154
155 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) {
156 auto dattr = cast<DenseIntElementsAttr>(Val&: attr);
157 int64_t num = 1;
158 for (auto index : dattr.getValues<APInt>()) {
159 num *= index.getZExtValue();
160 assert(num >= 0 && "integer overflow in element count computation");
161 }
162 return num;
163 }
164
165 auto *stc = val.get<ShapedTypeComponents *>();
166 int64_t num = 1;
167 for (int64_t dim : stc->getDims()) {
168 num *= dim;
169 assert(num >= 0 && "integer overflow in element count computation");
170 }
171 return num;
172}
173
174void ShapeAdaptor::dump() const {
175 if (!hasRank()) {
176 llvm::errs() << "<<unranked>>\n";
177 return;
178 }
179
180 SmallVector<int64_t> dims;
181 getDims(res&: dims);
182 auto mapped = llvm::map_range(C&: dims, F: [](int64_t dim) -> std::string {
183 if (ShapedType::isDynamic(dim))
184 return "?";
185 return llvm::formatv(Fmt: "{0}", Vals&: dim).str();
186 });
187 llvm::errs() << "rank = " << getRank() << " dims = [";
188 llvm::interleave(c: mapped, os&: llvm::errs(), separator: "x");
189 llvm::errs() << "]\n";
190}
191
192ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
193 Value val = operator[](Index: index);
194 if (valueToShape)
195 if (ShapeAdaptor ret = valueToShape(val))
196 return ret;
197
198 DenseIntElementsAttr attr;
199 if (!matchPattern(value: val, pattern: m_Constant(bind_value: &attr)))
200 return nullptr;
201 if (attr.getType().getRank() != 1)
202 return nullptr;
203 return attr;
204}
205
206ShapeAdaptor ValueShapeRange::getShape(Value val) const {
207 if (operandShape)
208 if (ShapeAdaptor ret = operandShape(val))
209 return ret;
210 return val.getType();
211}
212
213ShapeAdaptor ValueShapeRange::getShape(int index) const {
214 if (index < 0 || static_cast<size_t>(index) >= size())
215 return nullptr;
216 return getShape(val: operator[](Index: index));
217}
218
219LogicalResult mlir::detail::inferReturnTensorTypes(
220 ArrayRef<ShapedTypeComponents> retComponents,
221 SmallVectorImpl<Type> &inferredReturnTypes) {
222 for (const auto &shapeAndType : retComponents) {
223 Type elementTy = shapeAndType.getElementType();
224 assert(elementTy && "element type required to construct tensor");
225
226 Attribute attr = shapeAndType.getAttribute();
227 if (shapeAndType.hasRank()) {
228 inferredReturnTypes.push_back(
229 RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
230 } else {
231 assert(attr == nullptr && "attribute not supported");
232 inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
233 }
234 }
235 return success();
236}
237
238LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
239 SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
240 auto retTypeFn = cast<InferTypeOpInterface>(op);
241 auto result = retTypeFn.refineReturnTypes(
242 op->getContext(), op->getLoc(), op->getOperands(),
243 op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(),
244 inferredReturnTypes);
245 if (failed(result))
246 op->emitOpError() << "failed to infer returned types";
247
248 return result;
249}
250

source code of mlir/lib/Interfaces/InferTypeOpInterface.cpp