| 1 | //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains the definitions of the infer op interfaces defined in |
| 10 | // `InferTypeOpInterface.td`. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
| 19 | |
| 20 | using namespace mlir; |
| 21 | |
| 22 | namespace mlir { |
| 23 | #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" |
| 24 | } // namespace mlir |
| 25 | |
| 26 | LogicalResult |
| 27 | mlir::reifyResultShapes(OpBuilder &b, Operation *op, |
| 28 | ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| 29 | auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); |
| 30 | if (!reifiableOp) |
| 31 | return failure(); |
| 32 | LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes); |
| 33 | #ifndef NDEBUG |
| 34 | if (failed(Result: status)) |
| 35 | return failure(); |
| 36 | // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced |
| 37 | // a correct result. |
| 38 | int64_t resultIdx = 0; |
| 39 | for (OpResult result : op->getResults()) { |
| 40 | auto shapedType = dyn_cast<ShapedType>(result.getType()); |
| 41 | if (!shapedType) |
| 42 | continue; |
| 43 | if (!shapedType.hasRank()) { |
| 44 | // Nothing to check for unranked shaped values. |
| 45 | ++resultIdx; |
| 46 | continue; |
| 47 | } |
| 48 | // Assert one OpFoldResult per dimension. |
| 49 | assert(shapedType.getRank() == |
| 50 | static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) && |
| 51 | "incorrect implementation of ReifyRankedShapedTypeOpInterface" ); |
| 52 | ++resultIdx; |
| 53 | } |
| 54 | // Assert that every shaped value result was reified. |
| 55 | assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) && |
| 56 | "incorrect implementation of ReifyRankedShapedTypeOpInterface" ); |
| 57 | #endif // NDEBUG |
| 58 | return status; |
| 59 | } |
| 60 | |
| 61 | bool ShapeAdaptor::hasRank() const { |
| 62 | if (val.isNull()) |
| 63 | return false; |
| 64 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| 65 | return cast<ShapedType>(t).hasRank(); |
| 66 | if (isa<Attribute>(Val: val)) |
| 67 | return true; |
| 68 | return cast<ShapedTypeComponents *>(Val: val)->hasRank(); |
| 69 | } |
| 70 | |
| 71 | Type ShapeAdaptor::getElementType() const { |
| 72 | if (val.isNull()) |
| 73 | return nullptr; |
| 74 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| 75 | return cast<ShapedType>(t).getElementType(); |
| 76 | if (isa<Attribute>(Val: val)) |
| 77 | return nullptr; |
| 78 | return cast<ShapedTypeComponents *>(Val: val)->getElementType(); |
| 79 | } |
| 80 | |
| 81 | void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const { |
| 82 | assert(hasRank()); |
| 83 | if (auto t = llvm::dyn_cast_if_present<Type>(Val: val)) { |
| 84 | ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape(); |
| 85 | res.assign(in_start: vals.begin(), in_end: vals.end()); |
| 86 | } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) { |
| 87 | auto dattr = cast<DenseIntElementsAttr>(Val&: attr); |
| 88 | res.clear(); |
| 89 | res.reserve(N: dattr.size()); |
| 90 | for (auto it : dattr.getValues<APInt>()) |
| 91 | res.push_back(it.getSExtValue()); |
| 92 | } else { |
| 93 | auto vals = cast<ShapedTypeComponents *>(Val: val)->getDims(); |
| 94 | res.assign(in_start: vals.begin(), in_end: vals.end()); |
| 95 | } |
| 96 | } |
| 97 | |
| 98 | void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { |
| 99 | assert(hasRank()); |
| 100 | res.ranked = true; |
| 101 | getDims(res&: res.dims); |
| 102 | } |
| 103 | |
| 104 | int64_t ShapeAdaptor::getDimSize(int index) const { |
| 105 | assert(hasRank()); |
| 106 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| 107 | return cast<ShapedType>(t).getDimSize(index); |
| 108 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) |
| 109 | return cast<DenseIntElementsAttr>(attr) |
| 110 | .getValues<APInt>()[index] |
| 111 | .getSExtValue(); |
| 112 | auto *stc = cast<ShapedTypeComponents *>(Val: val); |
| 113 | return stc->getDims()[index]; |
| 114 | } |
| 115 | |
| 116 | int64_t ShapeAdaptor::getRank() const { |
| 117 | assert(hasRank()); |
| 118 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| 119 | return cast<ShapedType>(t).getRank(); |
| 120 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) |
| 121 | return cast<DenseIntElementsAttr>(Val&: attr).size(); |
| 122 | return cast<ShapedTypeComponents *>(Val: val)->getDims().size(); |
| 123 | } |
| 124 | |
| 125 | bool ShapeAdaptor::hasStaticShape() const { |
| 126 | if (!hasRank()) |
| 127 | return false; |
| 128 | |
| 129 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| 130 | return cast<ShapedType>(t).hasStaticShape(); |
| 131 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) { |
| 132 | auto dattr = cast<DenseIntElementsAttr>(Val&: attr); |
| 133 | for (auto index : dattr.getValues<APInt>()) |
| 134 | if (ShapedType::isDynamic(index.getSExtValue())) |
| 135 | return false; |
| 136 | return true; |
| 137 | } |
| 138 | auto *stc = cast<ShapedTypeComponents *>(Val: val); |
| 139 | return llvm::none_of(stc->getDims(), ShapedType::isDynamic); |
| 140 | } |
| 141 | |
| 142 | int64_t ShapeAdaptor::getNumElements() const { |
| 143 | assert(hasStaticShape() && "cannot get element count of dynamic shaped type" ); |
| 144 | |
| 145 | if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| 146 | return cast<ShapedType>(t).getNumElements(); |
| 147 | |
| 148 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val: val)) { |
| 149 | auto dattr = cast<DenseIntElementsAttr>(Val&: attr); |
| 150 | int64_t num = 1; |
| 151 | for (auto index : dattr.getValues<APInt>()) { |
| 152 | num *= index.getZExtValue(); |
| 153 | assert(num >= 0 && "integer overflow in element count computation" ); |
| 154 | } |
| 155 | return num; |
| 156 | } |
| 157 | |
| 158 | auto *stc = cast<ShapedTypeComponents *>(Val: val); |
| 159 | int64_t num = 1; |
| 160 | for (int64_t dim : stc->getDims()) { |
| 161 | num *= dim; |
| 162 | assert(num >= 0 && "integer overflow in element count computation" ); |
| 163 | } |
| 164 | return num; |
| 165 | } |
| 166 | |
| 167 | void ShapeAdaptor::dump() const { |
| 168 | if (!hasRank()) { |
| 169 | llvm::errs() << "<<unranked>>\n" ; |
| 170 | return; |
| 171 | } |
| 172 | |
| 173 | SmallVector<int64_t> dims; |
| 174 | getDims(res&: dims); |
| 175 | auto mapped = llvm::map_range(C&: dims, F: [](int64_t dim) -> std::string { |
| 176 | if (ShapedType::isDynamic(dim)) |
| 177 | return "?" ; |
| 178 | return llvm::formatv(Fmt: "{0}" , Vals&: dim).str(); |
| 179 | }); |
| 180 | llvm::errs() << "rank = " << getRank() |
| 181 | << " dims = " << llvm::interleaved_array(R: mapped, Separator: "x" ) << "\n" ; |
| 182 | } |
| 183 | |
| 184 | ShapeAdaptor ValueShapeRange::getValueAsShape(int index) { |
| 185 | Value val = operator[](Index: index); |
| 186 | if (valueToShape) |
| 187 | if (ShapeAdaptor ret = valueToShape(val)) |
| 188 | return ret; |
| 189 | |
| 190 | DenseIntElementsAttr attr; |
| 191 | if (!matchPattern(value: val, pattern: m_Constant(bind_value: &attr))) |
| 192 | return nullptr; |
| 193 | if (attr.getType().getRank() != 1) |
| 194 | return nullptr; |
| 195 | return attr; |
| 196 | } |
| 197 | |
| 198 | ShapeAdaptor ValueShapeRange::getShape(Value val) const { |
| 199 | if (operandShape) |
| 200 | if (ShapeAdaptor ret = operandShape(val)) |
| 201 | return ret; |
| 202 | return val.getType(); |
| 203 | } |
| 204 | |
| 205 | ShapeAdaptor ValueShapeRange::getShape(int index) const { |
| 206 | if (index < 0 || static_cast<size_t>(index) >= size()) |
| 207 | return nullptr; |
| 208 | return getShape(val: operator[](Index: index)); |
| 209 | } |
| 210 | |
| 211 | LogicalResult mlir::detail::inferReturnTensorTypes( |
| 212 | ArrayRef<ShapedTypeComponents> retComponents, |
| 213 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 214 | for (const auto &shapeAndType : retComponents) { |
| 215 | Type elementTy = shapeAndType.getElementType(); |
| 216 | assert(elementTy && "element type required to construct tensor" ); |
| 217 | |
| 218 | Attribute attr = shapeAndType.getAttribute(); |
| 219 | if (shapeAndType.hasRank()) { |
| 220 | inferredReturnTypes.push_back( |
| 221 | RankedTensorType::get(shapeAndType.getDims(), elementTy, attr)); |
| 222 | } else { |
| 223 | assert(attr == nullptr && "attribute not supported" ); |
| 224 | inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); |
| 225 | } |
| 226 | } |
| 227 | return success(); |
| 228 | } |
| 229 | |
| 230 | LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { |
| 231 | SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes()); |
| 232 | auto retTypeFn = cast<InferTypeOpInterface>(op); |
| 233 | auto result = retTypeFn.refineReturnTypes( |
| 234 | op->getContext(), op->getLoc(), op->getOperands(), |
| 235 | op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(), |
| 236 | inferredReturnTypes); |
| 237 | if (failed(result)) |
| 238 | op->emitOpError() << "failed to infer returned types" ; |
| 239 | |
| 240 | return result; |
| 241 | } |
| 242 | |
| 243 | void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) { |
| 244 | std::string buffer; |
| 245 | llvm::raw_string_ostream os(buffer); |
| 246 | os << "Failed to infer result type(s):\n" |
| 247 | << "\"" << state.name << "\"(...) " |
| 248 | << state.attributes.getDictionary(state.location.getContext()) << " : (" |
| 249 | << llvm::interleaved(R: llvm::map_range( |
| 250 | C&: state.operands, F: [](Value val) { return val.getType(); })) |
| 251 | << ") -> ( ??? )" ; |
| 252 | emitRemark(loc: state.location, message: "location of op" ); |
| 253 | llvm::report_fatal_error(reason: llvm::StringRef(buffer)); |
| 254 | } |
| 255 | |