1//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the definitions of the infer op interfaces defined in
10// `InferTypeOpInterface.td`.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"

#include <limits>
19
20using namespace mlir;
21
22namespace mlir {
23#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
24} // namespace mlir
25
26LogicalResult
27mlir::reifyResultShapes(OpBuilder &b, Operation *op,
28 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
29 auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
30 if (!reifiableOp)
31 return failure();
32 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
33#ifndef NDEBUG
34 if (failed(Result: status))
35 return failure();
36 // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
37 // a correct result.
38 int64_t resultIdx = 0;
39 for (OpResult result : op->getResults()) {
40 auto shapedType = dyn_cast<ShapedType>(result.getType());
41 if (!shapedType)
42 continue;
43 if (!shapedType.hasRank()) {
44 // Nothing to check for unranked shaped values.
45 ++resultIdx;
46 continue;
47 }
48 // Assert one OpFoldResult per dimension.
49 assert(shapedType.getRank() ==
50 static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
51 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
52 ++resultIdx;
53 }
54 // Assert that every shaped value result was reified.
55 assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
56 "incorrect implementation of ReifyRankedShapedTypeOpInterface");
57#endif // NDEBUG
58 return status;
59}
60
61bool ShapeAdaptor::hasRank() const {
62 if (val.isNull())
63 return false;
64 if (auto t = llvm::dyn_cast_if_present<Type>(val))
65 return cast<ShapedType>(t).hasRank();
66 if (isa<Attribute>(Val: val))
67 return true;
68 return cast<ShapedTypeComponents *>(Val: val)->hasRank();
69}
70
71Type ShapeAdaptor::getElementType() const {
72 if (val.isNull())
73 return nullptr;
74 if (auto t = llvm::dyn_cast_if_present<Type>(val))
75 return cast<ShapedType>(t).getElementType();
76 if (isa<Attribute>(Val: val))
77 return nullptr;
78 return cast<ShapedTypeComponents *>(Val: val)->getElementType();
79}
80
81void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
82 assert(hasRank());
83 if (auto t = llvm::dyn_cast_if_present<Type>(Val: val)) {
84 ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
85 res.assign(in_start: vals.begin(), in_end: vals.end());
86 } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) {
87 auto dattr = cast<DenseIntElementsAttr>(Val&: attr);
88 res.clear();
89 res.reserve(N: dattr.size());
90 for (auto it : dattr.getValues<APInt>())
91 res.push_back(it.getSExtValue());
92 } else {
93 auto vals = cast<ShapedTypeComponents *>(Val: val)->getDims();
94 res.assign(in_start: vals.begin(), in_end: vals.end());
95 }
96}
97
98void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
99 assert(hasRank());
100 res.ranked = true;
101 getDims(res&: res.dims);
102}
103
104int64_t ShapeAdaptor::getDimSize(int index) const {
105 assert(hasRank());
106 if (auto t = llvm::dyn_cast_if_present<Type>(val))
107 return cast<ShapedType>(t).getDimSize(index);
108 if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
109 return cast<DenseIntElementsAttr>(attr)
110 .getValues<APInt>()[index]
111 .getSExtValue();
112 auto *stc = cast<ShapedTypeComponents *>(Val: val);
113 return stc->getDims()[index];
114}
115
116int64_t ShapeAdaptor::getRank() const {
117 assert(hasRank());
118 if (auto t = llvm::dyn_cast_if_present<Type>(val))
119 return cast<ShapedType>(t).getRank();
120 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val))
121 return cast<DenseIntElementsAttr>(Val&: attr).size();
122 return cast<ShapedTypeComponents *>(Val: val)->getDims().size();
123}
124
125bool ShapeAdaptor::hasStaticShape() const {
126 if (!hasRank())
127 return false;
128
129 if (auto t = llvm::dyn_cast_if_present<Type>(val))
130 return cast<ShapedType>(t).hasStaticShape();
131 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) {
132 auto dattr = cast<DenseIntElementsAttr>(Val&: attr);
133 for (auto index : dattr.getValues<APInt>())
134 if (ShapedType::isDynamic(index.getSExtValue()))
135 return false;
136 return true;
137 }
138 auto *stc = cast<ShapedTypeComponents *>(Val: val);
139 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
140}
141
142int64_t ShapeAdaptor::getNumElements() const {
143 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
144
145 if (auto t = llvm::dyn_cast_if_present<Type>(val))
146 return cast<ShapedType>(t).getNumElements();
147
148 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) {
149 auto dattr = cast<DenseIntElementsAttr>(Val&: attr);
150 int64_t num = 1;
151 for (auto index : dattr.getValues<APInt>()) {
152 num *= index.getZExtValue();
153 assert(num >= 0 && "integer overflow in element count computation");
154 }
155 return num;
156 }
157
158 auto *stc = cast<ShapedTypeComponents *>(Val: val);
159 int64_t num = 1;
160 for (int64_t dim : stc->getDims()) {
161 num *= dim;
162 assert(num >= 0 && "integer overflow in element count computation");
163 }
164 return num;
165}
166
167void ShapeAdaptor::dump() const {
168 if (!hasRank()) {
169 llvm::errs() << "<<unranked>>\n";
170 return;
171 }
172
173 SmallVector<int64_t> dims;
174 getDims(res&: dims);
175 auto mapped = llvm::map_range(C&: dims, F: [](int64_t dim) -> std::string {
176 if (ShapedType::isDynamic(dim))
177 return "?";
178 return llvm::formatv(Fmt: "{0}", Vals&: dim).str();
179 });
180 llvm::errs() << "rank = " << getRank()
181 << " dims = " << llvm::interleaved_array(R: mapped, Separator: "x") << "\n";
182}
183
184ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
185 Value val = operator[](Index: index);
186 if (valueToShape)
187 if (ShapeAdaptor ret = valueToShape(val))
188 return ret;
189
190 DenseIntElementsAttr attr;
191 if (!matchPattern(value: val, pattern: m_Constant(bind_value: &attr)))
192 return nullptr;
193 if (attr.getType().getRank() != 1)
194 return nullptr;
195 return attr;
196}
197
198ShapeAdaptor ValueShapeRange::getShape(Value val) const {
199 if (operandShape)
200 if (ShapeAdaptor ret = operandShape(val))
201 return ret;
202 return val.getType();
203}
204
205ShapeAdaptor ValueShapeRange::getShape(int index) const {
206 if (index < 0 || static_cast<size_t>(index) >= size())
207 return nullptr;
208 return getShape(val: operator[](Index: index));
209}
210
211LogicalResult mlir::detail::inferReturnTensorTypes(
212 ArrayRef<ShapedTypeComponents> retComponents,
213 SmallVectorImpl<Type> &inferredReturnTypes) {
214 for (const auto &shapeAndType : retComponents) {
215 Type elementTy = shapeAndType.getElementType();
216 assert(elementTy && "element type required to construct tensor");
217
218 Attribute attr = shapeAndType.getAttribute();
219 if (shapeAndType.hasRank()) {
220 inferredReturnTypes.push_back(
221 RankedTensorType::get(shapeAndType.getDims(), elementTy, attr));
222 } else {
223 assert(attr == nullptr && "attribute not supported");
224 inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
225 }
226 }
227 return success();
228}
229
230LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
231 SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
232 auto retTypeFn = cast<InferTypeOpInterface>(op);
233 auto result = retTypeFn.refineReturnTypes(
234 op->getContext(), op->getLoc(), op->getOperands(),
235 op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(),
236 inferredReturnTypes);
237 if (failed(result))
238 op->emitOpError() << "failed to infer returned types";
239
240 return result;
241}
242
243void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) {
244 std::string buffer;
245 llvm::raw_string_ostream os(buffer);
246 os << "Failed to infer result type(s):\n"
247 << "\"" << state.name << "\"(...) "
248 << state.attributes.getDictionary(state.location.getContext()) << " : ("
249 << llvm::interleaved(R: llvm::map_range(
250 C&: state.operands, F: [](Value val) { return val.getType(); }))
251 << ") -> ( ??? )";
252 emitRemark(loc: state.location, message: "location of op");
253 llvm::report_fatal_error(reason: llvm::StringRef(buffer));
254}
255

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Interfaces/InferTypeOpInterface.cpp