| 1 | //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
#include "mlir/Interfaces/ViewLikeInterface.h"

#include "llvm/Support/CheckedArithmetic.h"
| 10 | |
| 11 | using namespace mlir; |
| 12 | |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | // ViewLike Interfaces |
| 15 | //===----------------------------------------------------------------------===// |
| 16 | |
/// Include the definitions of the view-like interfaces.
| 18 | #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" |
| 19 | |
| 20 | LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, |
| 21 | StringRef name, |
| 22 | unsigned numElements, |
| 23 | ArrayRef<int64_t> staticVals, |
| 24 | ValueRange values) { |
| 25 | // Check static and dynamic offsets/sizes/strides does not overflow type. |
| 26 | if (staticVals.size() != numElements) |
| 27 | return op->emitError(message: "expected " ) << numElements << " " << name |
| 28 | << " values, got " << staticVals.size(); |
| 29 | unsigned expectedNumDynamicEntries = |
| 30 | llvm::count_if(staticVals, ShapedType::isDynamic); |
| 31 | if (values.size() != expectedNumDynamicEntries) |
| 32 | return op->emitError(message: "expected " ) |
| 33 | << expectedNumDynamicEntries << " dynamic " << name << " values" ; |
| 34 | return success(); |
| 35 | } |
| 36 | |
| 37 | SliceBoundsVerificationResult mlir::verifyInBoundsSlice( |
| 38 | ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets, |
| 39 | ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, |
| 40 | bool generateErrorMessage) { |
| 41 | SliceBoundsVerificationResult result; |
| 42 | result.isValid = true; |
| 43 | for (int64_t i = 0, e = shape.size(); i < e; ++i) { |
| 44 | // Nothing to verify for dynamic source dims. |
| 45 | if (ShapedType::isDynamic(shape[i])) |
| 46 | continue; |
| 47 | // Nothing to verify if the offset is dynamic. |
| 48 | if (ShapedType::isDynamic(staticOffsets[i])) |
| 49 | continue; |
| 50 | if (staticOffsets[i] >= shape[i]) { |
| 51 | result.errorMessage = |
| 52 | std::string("offset " ) + std::to_string(val: i) + |
| 53 | " is out-of-bounds: " + std::to_string(val: staticOffsets[i]) + |
| 54 | " >= " + std::to_string(val: shape[i]); |
| 55 | result.isValid = false; |
| 56 | return result; |
| 57 | } |
| 58 | if (ShapedType::isDynamic(staticSizes[i]) || |
| 59 | ShapedType::isDynamic(staticStrides[i])) |
| 60 | continue; |
| 61 | int64_t lastPos = |
| 62 | staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i]; |
| 63 | if (lastPos >= shape[i]) { |
| 64 | result.errorMessage = std::string("slice along dimension " ) + |
| 65 | std::to_string(val: i) + |
| 66 | " runs out-of-bounds: " + std::to_string(val: lastPos) + |
| 67 | " >= " + std::to_string(val: shape[i]); |
| 68 | result.isValid = false; |
| 69 | return result; |
| 70 | } |
| 71 | } |
| 72 | return result; |
| 73 | } |
| 74 | |
| 75 | SliceBoundsVerificationResult mlir::verifyInBoundsSlice( |
| 76 | ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets, |
| 77 | ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides, |
| 78 | bool generateErrorMessage) { |
| 79 | auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) { |
| 80 | SmallVector<int64_t> staticValues; |
| 81 | for (OpFoldResult ofr : ofrs) { |
| 82 | if (auto attr = dyn_cast<Attribute>(Val&: ofr)) { |
| 83 | staticValues.push_back(Elt: cast<IntegerAttr>(attr).getInt()); |
| 84 | } else { |
| 85 | staticValues.push_back(ShapedType::kDynamic); |
| 86 | } |
| 87 | } |
| 88 | return staticValues; |
| 89 | }; |
| 90 | return verifyInBoundsSlice( |
| 91 | shape, staticOffsets: getStaticValues(mixedOffsets), staticSizes: getStaticValues(mixedSizes), |
| 92 | staticStrides: getStaticValues(mixedStrides), generateErrorMessage); |
| 93 | } |
| 94 | |
| 95 | LogicalResult |
| 96 | mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { |
| 97 | std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); |
| 98 | // Offsets can come in 2 flavors: |
| 99 | // 1. Either single entry (when maxRanks == 1). |
| 100 | // 2. Or as an array whose rank must match that of the mixed sizes. |
| 101 | // So that the result type is well-formed. |
| 102 | if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT |
| 103 | op.getMixedOffsets().size() != op.getMixedSizes().size()) |
| 104 | return op->emitError( |
| 105 | "expected mixed offsets rank to match mixed sizes rank (" ) |
| 106 | << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() |
| 107 | << ") so the rank of the result type is well-formed." ; |
| 108 | // Ranks of mixed sizes and strides must always match so the result type is |
| 109 | // well-formed. |
| 110 | if (op.getMixedSizes().size() != op.getMixedStrides().size()) |
| 111 | return op->emitError( |
| 112 | "expected mixed sizes rank to match mixed strides rank (" ) |
| 113 | << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() |
| 114 | << ") so the rank of the result type is well-formed." ; |
| 115 | |
| 116 | if (failed(verifyListOfOperandsOrIntegers( |
| 117 | op, "offset" , maxRanks[0], op.getStaticOffsets(), op.getOffsets()))) |
| 118 | return failure(); |
| 119 | if (failed(verifyListOfOperandsOrIntegers( |
| 120 | op, "size" , maxRanks[1], op.getStaticSizes(), op.getSizes()))) |
| 121 | return failure(); |
| 122 | if (failed(verifyListOfOperandsOrIntegers( |
| 123 | op, "stride" , maxRanks[2], op.getStaticStrides(), op.getStrides()))) |
| 124 | return failure(); |
| 125 | |
| 126 | for (int64_t offset : op.getStaticOffsets()) { |
| 127 | if (offset < 0 && !ShapedType::isDynamic(offset)) |
| 128 | return op->emitError("expected offsets to be non-negative, but got " ) |
| 129 | << offset; |
| 130 | } |
| 131 | for (int64_t size : op.getStaticSizes()) { |
| 132 | if (size < 0 && !ShapedType::isDynamic(size)) |
| 133 | return op->emitError("expected sizes to be non-negative, but got " ) |
| 134 | << size; |
| 135 | } |
| 136 | return success(); |
| 137 | } |
| 138 | |
| 139 | static char getLeftDelimiter(AsmParser::Delimiter delimiter) { |
| 140 | switch (delimiter) { |
| 141 | case AsmParser::Delimiter::Paren: |
| 142 | return '('; |
| 143 | case AsmParser::Delimiter::LessGreater: |
| 144 | return '<'; |
| 145 | case AsmParser::Delimiter::Square: |
| 146 | return '['; |
| 147 | case AsmParser::Delimiter::Braces: |
| 148 | return '{'; |
| 149 | default: |
| 150 | llvm_unreachable("unsupported delimiter" ); |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | static char getRightDelimiter(AsmParser::Delimiter delimiter) { |
| 155 | switch (delimiter) { |
| 156 | case AsmParser::Delimiter::Paren: |
| 157 | return ')'; |
| 158 | case AsmParser::Delimiter::LessGreater: |
| 159 | return '>'; |
| 160 | case AsmParser::Delimiter::Square: |
| 161 | return ']'; |
| 162 | case AsmParser::Delimiter::Braces: |
| 163 | return '}'; |
| 164 | default: |
| 165 | llvm_unreachable("unsupported delimiter" ); |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, |
| 170 | OperandRange values, |
| 171 | ArrayRef<int64_t> integers, |
| 172 | ArrayRef<bool> scalableFlags, |
| 173 | TypeRange valueTypes, |
| 174 | AsmParser::Delimiter delimiter) { |
| 175 | char leftDelimiter = getLeftDelimiter(delimiter); |
| 176 | char rightDelimiter = getRightDelimiter(delimiter); |
| 177 | printer << leftDelimiter; |
| 178 | if (integers.empty()) { |
| 179 | printer << rightDelimiter; |
| 180 | return; |
| 181 | } |
| 182 | |
| 183 | unsigned dynamicValIdx = 0; |
| 184 | unsigned scalableIndexIdx = 0; |
| 185 | llvm::interleaveComma(c: integers, os&: printer, each_fn: [&](int64_t integer) { |
| 186 | if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) |
| 187 | printer << "[" ; |
| 188 | if (ShapedType::isDynamic(integer)) { |
| 189 | printer << values[dynamicValIdx]; |
| 190 | if (!valueTypes.empty()) |
| 191 | printer << " : " << valueTypes[dynamicValIdx]; |
| 192 | ++dynamicValIdx; |
| 193 | } else { |
| 194 | printer << integer; |
| 195 | } |
| 196 | if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) |
| 197 | printer << "]" ; |
| 198 | |
| 199 | scalableIndexIdx++; |
| 200 | }); |
| 201 | |
| 202 | printer << rightDelimiter; |
| 203 | } |
| 204 | |
| 205 | ParseResult mlir::parseDynamicIndexList( |
| 206 | OpAsmParser &parser, |
| 207 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
| 208 | DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, |
| 209 | SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { |
| 210 | |
| 211 | SmallVector<int64_t, 4> integerVals; |
| 212 | SmallVector<bool, 4> scalableVals; |
| 213 | auto parseIntegerOrValue = [&]() { |
| 214 | OpAsmParser::UnresolvedOperand operand; |
| 215 | auto res = parser.parseOptionalOperand(result&: operand); |
| 216 | |
| 217 | // When encountering `[`, assume that this is a scalable index. |
| 218 | scalableVals.push_back(Elt: parser.parseOptionalLSquare().succeeded()); |
| 219 | |
| 220 | if (res.has_value() && succeeded(Result: res.value())) { |
| 221 | values.push_back(Elt: operand); |
| 222 | integerVals.push_back(ShapedType::kDynamic); |
| 223 | if (valueTypes && parser.parseColonType(result&: valueTypes->emplace_back())) |
| 224 | return failure(); |
| 225 | } else { |
| 226 | int64_t integer; |
| 227 | if (failed(Result: parser.parseInteger(result&: integer))) |
| 228 | return failure(); |
| 229 | integerVals.push_back(Elt: integer); |
| 230 | } |
| 231 | |
| 232 | // If this is assumed to be a scalable index, verify that there's a closing |
| 233 | // `]`. |
| 234 | if (scalableVals.back() && parser.parseOptionalRSquare().failed()) |
| 235 | return failure(); |
| 236 | return success(); |
| 237 | }; |
| 238 | if (parser.parseCommaSeparatedList(delimiter, parseElementFn: parseIntegerOrValue, |
| 239 | contextMessage: " in dynamic index list" )) |
| 240 | return parser.emitError(loc: parser.getNameLoc()) |
| 241 | << "expected SSA value or integer" ; |
| 242 | integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); |
| 243 | scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); |
| 244 | return success(); |
| 245 | } |
| 246 | |
| 247 | bool mlir::detail::sameOffsetsSizesAndStrides( |
| 248 | OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, |
| 249 | llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { |
| 250 | if (a.getStaticOffsets().size() != b.getStaticOffsets().size()) |
| 251 | return false; |
| 252 | if (a.getStaticSizes().size() != b.getStaticSizes().size()) |
| 253 | return false; |
| 254 | if (a.getStaticStrides().size() != b.getStaticStrides().size()) |
| 255 | return false; |
| 256 | for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) |
| 257 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
| 258 | return false; |
| 259 | for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) |
| 260 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
| 261 | return false; |
| 262 | for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) |
| 263 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
| 264 | return false; |
| 265 | return true; |
| 266 | } |
| 267 | |
| 268 | unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, |
| 269 | unsigned idx) { |
| 270 | return std::count_if(staticVals.begin(), staticVals.begin() + idx, |
| 271 | ShapedType::isDynamic); |
| 272 | } |
| 273 | |